Compare commits

...

114 Commits

Author SHA1 Message Date
Awni Hannun
67ec27d515 synch before reading memory in test 2025-04-07 14:37:32 -07:00
Awni Hannun
f140792f1c more donation take 2 2025-04-07 14:37:32 -07:00
Awni Hannun
35dc8580e3 fix events, task limit cpu 2025-04-07 14:37:32 -07:00
Awni Hannun
392bca56a6 fix perf regression, remove unnecessary abort 2025-04-07 14:37:32 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::variant from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add new specialized kernel for small kernels (kernel size <= 7), small strides (strides <= 2) depthwise 2D convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING.md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in the ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
Awni Hannun
1177d28395 patch bump (#1981) 2025-03-20 15:12:22 -07:00
Awni Hannun
005e7efa64 fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84 Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650 Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
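
A minimal sketch of the ``mask="causal"`` usage that #1924 describes (the shapes, dtype, and scale below are illustrative assumptions, not values from the PR):

.. code:: python

    import math
    import mlx.core as mx

    B, H, L, D = 1, 8, 128, 64  # hypothetical (batch, heads, seq, head_dim)
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))

    # Per #1924, fast.scaled_dot_product_attention accepts the string
    # "causal" in addition to boolean or additive array masks.
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=1.0 / math.sqrt(D), mask="causal"
    )
    mx.eval(out)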
Awni Hannun
3c164fca8c Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock

* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b Update smooth_l1_loss in losses.py (#1974)
According to the definition of smooth_l1_loss, the line

diff = predictions - targets

should be updated to

diff = mx.abs(predictions - targets)

After this modification, the result is consistent with PyTorch's smooth_l1_loss
2025-03-19 20:19:02 -07:00
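
For reference, a minimal sketch of the corrected piecewise loss (signature and reduction simplified relative to ``mlx.nn.losses``):

.. code:: python

    import mlx.core as mx

    def smooth_l1_loss(predictions, targets, beta=1.0):
        # mx.abs is required: without it, negative residuals would select
        # the wrong branch of the piecewise definition.
        diff = mx.abs(predictions - targets)
        loss = mx.where(
            diff < beta,
            0.5 * mx.square(diff) / beta,
            diff - 0.5 * beta,
        )
        return mx.mean(loss)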
Awni Hannun
f90206ad74 Guard nullptr dereference (#1972)
* guard nullptr dereference

* comment
2025-03-19 16:24:10 -07:00
Chunyang Wen
3779150750 refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
Cheng
0a9777aa5c Do not define MLX_VERSION globally (#1966) 2025-03-18 07:12:40 -07:00
Chunyang Wen
45ad06aac8 Fix typo; Fix lint warning when reuse the same name (#1968)
* Fix typo; Fix lint warning when reuse the same name

* Add missing period
2025-03-18 07:12:24 -07:00
Awni Hannun
c6ea2ba329 Use same accumulation precision in gemv as gemm (#1962)
* use same accumulation precision in gemv as gemm

* faster

* fix compile
2025-03-16 07:13:24 -07:00
Awni Hannun
2770a10240 fix grad with inplace updates (#1961) 2025-03-13 19:13:09 -07:00
Awni Hannun
d2a94f9e6a Only compile warnings as errors for circle (#1957) 2025-03-12 13:08:19 -07:00
Awni Hannun
32da94507a fix vmap for flatten (#1955) 2025-03-11 10:42:22 -07:00
Awni Hannun
736a340478 reduce binary size (#1952) 2025-03-11 06:30:44 -07:00
Awni Hannun
117e1355a2 fix copy for large arrays (#1953) 2025-03-10 15:04:25 -07:00
Awni Hannun
3c3e558c60 Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv

* fix flaky test

* nit
2025-03-10 10:53:45 -07:00
Chunyang Wen
cffceda6ee Add type hint for _extra_repr (#1948) 2025-03-10 06:05:36 -07:00
Chunyang Wen
048805ad2c Remove unused modules (#1949) 2025-03-10 06:05:26 -07:00
Chunyang Wen
d14c9fe7ea Add file info when raising errors in save (#1943) 2025-03-08 14:51:04 -08:00
Chunyang Wen
5db90ce822 Fix obscured warning (#1944) 2025-03-08 14:50:39 -08:00
Chunyang Wen
d699cc1330 Fix unreachable warning (#1939)
* Fix unreachable warning

* Update error message
2025-03-07 17:23:04 -08:00
Awni Hannun
c4230747a1 redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
2025-03-06 19:23:38 -08:00
Awni Hannun
5245f12a46 always use json (#1938) 2025-03-06 15:35:56 -08:00
Chunyang Wen
a198b2787e Remove unused modules (#1936) 2025-03-06 14:20:27 -08:00
Chunyang Wen
04edad8c59 Add doc string for path (#1937) 2025-03-06 14:20:09 -08:00
David Wisdom
392b3060b0 Fix typo in randint docstring (#1932)
This commit fixes a typo in the docstring for mlx.core.random.randint() by changing "roadcastable" to "broadcastable".
2025-03-05 21:48:00 -08:00
Chunyang Wen
85b34d59bc Clean unused sys (#1929) 2025-03-05 13:48:03 -08:00
Awni Hannun
f599c11bc8 bump (#1931) 2025-03-05 13:16:53 -08:00
Angelos Katharopoulos
0792ff02ff Only fail when 10 consecutive socket errors occur (#1928) 2025-03-05 13:16:19 -08:00
Alex Barron
fd0d63ba5b Affine quant always in fp32 (#1925)
* do affine quant in fp32

* static cast
2025-03-04 17:50:19 -08:00
Abe Leininger
3835a428c5 Adds nuclear norm support (#1894)
* adjust norm unit test tolerance
2025-03-04 13:26:02 -08:00
Angelos Katharopoulos
9680f72cca Add a multi optimizer (#1916) 2025-03-04 13:16:35 -08:00
Angelos Katharopoulos
a0737273d3 Allow debugging in distributed mode (#1920) 2025-03-04 13:01:10 -08:00
Awni Hannun
e613d0eaf0 SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa

* batch sdpa for query
2025-03-04 10:59:04 -08:00
Awni Hannun
6bcd6bcf70 fix donation in scan (#1917) 2025-03-03 11:30:59 -08:00
Awni Hannun
ba12e4999a Use a heap for small sizes (#1911)
* use a heap for small sizes

* check if VM
2025-03-03 06:50:57 -08:00
Awni Hannun
4e7cd31d12 Fix slice data size (#1913)
* fix slice data size

* add test
2025-03-02 21:50:42 -08:00
Angelos Katharopoulos
5e6c130d93 RMS norm without scaling (#1915) 2025-02-28 20:26:57 -08:00
Angelos Katharopoulos
5d68082881 Ring docs (#1829) 2025-02-28 11:34:21 -08:00
Angelos Katharopoulos
607181644f Add mlx.distributed_config script (#1902) 2025-02-28 11:16:39 -08:00
Jagrit Digani
89d327075f Enabling fused attention for head dim 128 (#1899)
* Share KV smem

* Fix bfloat error

* Unroll O = S @ V loop

* Perf upgrade

* Remove commented out function

* Add -Wno-c++17-extensions flag to metal flags

* Add -Wno-c++17-extensions flag to metal extension flags
2025-02-26 10:02:06 -08:00
Angelos Katharopoulos
6bf00ef631 Fix ring of 2 and allow scalars in API (#1906) 2025-02-25 17:03:01 -08:00
Awni Hannun
7d042f17fe Double for lapack (#1904)
* double for lapack ops

* add double support for lapack ops
2025-02-25 11:39:36 -08:00
Awni Hannun
28b8079e30 fix double type promotion (#1901) 2025-02-25 06:00:53 -08:00
Awni Hannun
7face5d9fd fix cpu compile (#1897) 2025-02-24 14:10:30 -08:00
Awni Hannun
a44dc4bdb0 fix leaking objc (#1898) 2025-02-24 13:57:59 -08:00
Awni Hannun
2d0f384b6f fix simd erf_inv (#1896) 2025-02-24 13:57:47 -08:00
Awni Hannun
8ff84b5c43 fix version and expose command queue getter (#1892) 2025-02-20 15:25:15 -08:00
Angelos Katharopoulos
10b271d963 Ring update (#1885) 2025-02-20 14:32:31 -08:00
Jesper Stemann Andersen
0ebc8a3d25 Fixed issue where Clang on FreeBSD failed to compile mlx/backend/cpu/quantized.cpp (#1890) 2025-02-20 12:02:12 -08:00
Awni Hannun
bbda0fdbdb Allow non-square lu (#1889) 2025-02-20 08:13:23 -08:00
Jesper Stemann Andersen
c86422bdd4 Added mlx::core::version() returning std::string(MLX_VERSION) (#1819)
* Added version.h providing mlx::core::version() returning std::string(MLX_VERSION)

Also, added MLX_VERSION_MAJOR, MLX_VERSION_MINOR, MLX_VERSION_PATCH, MLX_VERSION_NUMERIC, and accompanying functions.

* Added version.h to mlx.h

* Changed version int functions to be constexpr

* Formatting

* Added handling of MLX_VERSION where only the prefix has major.minor.patch format

* Changed version function to be constexpr
2025-02-19 20:30:19 -08:00
Awni Hannun
c707b2b0a6 Limit compile buffers (#1887)
* limit compile buffers

* maybe not flaky test
2025-02-19 20:28:13 -08:00
Angelos Katharopoulos
78ba24c37d Raise an exception in the rope op if input is integer (#1884) 2025-02-19 14:43:39 -08:00
Angelos Katharopoulos
1a2cb72030 Ensure linspace always contains start and stop (#1883) 2025-02-19 13:53:20 -08:00
Abe Leininger
344a29506e Enforce triangular matrix form in tri_inv (#1876)
* fix tri_inv bug

* Revert "fix tri_inv bug"

This reverts commit b74b290201.

* Make sure that tri_inv returns a triangular matrix

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-02-19 12:42:33 -08:00
Angelos Katharopoulos
71de73a668 Fix convs by reverting #1803 (#1882) 2025-02-18 14:36:34 -08:00
Alex Barron
4c1dfa58b7 xor op on arrays (#1875) 2025-02-17 00:24:53 -08:00
Awni Hannun
5274c3c43f compiler warnings are errors (#1870) 2025-02-17 00:07:49 -08:00
Angelos Katharopoulos
1762793989 Remove unused uniform (#1867) 2025-02-14 15:51:41 -08:00
251 changed files with 12346 additions and 6893 deletions

View File

@@ -24,8 +24,8 @@ jobs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
xcode: "16.2.0"
resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -89,6 +89,7 @@ jobs:
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
@@ -108,6 +109,8 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
@@ -122,10 +125,15 @@ jobs:
parameters:
xcode_version:
type: string
default: "15.2.0"
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -146,7 +154,9 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Generate package stubs
command: |
@@ -209,13 +219,18 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
default: "16.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
@@ -331,7 +346,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
- build_documentation
@@ -351,8 +366,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -375,7 +452,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -388,7 +465,54 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when:
and:
@@ -399,8 +523,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:

View File

@@ -1,6 +1,24 @@
cmake_minimum_required(VERSION 3.25)
project(mlx LANGUAGES C CXX)
if(NOT MLX_VERSION)
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
set(_major ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
set(_minor ${CMAKE_MATCH_1})
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
endif()
project(
mlx
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -24,13 +42,7 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.23.0)
endif()
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
# --------------------- Processor tests -------------------------
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
@@ -200,23 +212,13 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```
clang-format -i file.cpp
```
```
black file.py
```
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch

View File

@@ -10,7 +10,12 @@ def layer_norm(x, w, b, eps):
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
y = (x - mu) * mx.rsqrt(v + eps)
if w is not None:
y = y * w
if b is not None:
y = y + b
return y
def time_layer_norm():
@@ -36,6 +41,28 @@ def time_layer_norm():
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(layer_norm_loop, g1, x)
time_fn(layer_norm_loop, g2, x)
time_fn(layer_norm_loop, mx.compile(g1), x)
time_fn(layer_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -9,7 +9,10 @@ def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
y = (x * n).astype(ot)
if w is not None:
y = y * w
return y
def time_rms_norm():
@@ -34,6 +37,27 @@ def time_rms_norm():
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
time_fn(rms_norm_loop, g1, x)
time_fn(rms_norm_loop, g2, x)
time_fn(rms_norm_loop, mx.compile(g1), x)
time_fn(rms_norm_loop, mx.compile(g2), x)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -28,11 +28,34 @@ def bench(f, *args):
return (e - s) * 1e-9
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
B = q.shape[0]
L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
if mask is not None:
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v
if n_repeats > 1:
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
def mlx_fused_attn(q, k, v, scale, mask):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
def bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
time_mlx_unfused = bench(
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
scale = math.sqrt(1.0 / head_dim)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
atol = 1e-5 if dtype == "float32" else 2e-4
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
@@ -151,39 +173,51 @@ if __name__ == "__main__":
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 8),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 8),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 8),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape(
B,
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -1,5 +1,7 @@
include(CMakeParseArguments)
# clang format off
#
# ##############################################################################
# Build metal library
#
@@ -11,6 +13,8 @@ include(CMakeParseArguments)
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers)
#
# clang format on
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
@@ -21,7 +25,7 @@ macro(mlx_build_metallib)
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
# Prepare metallib build command
add_custom_command(

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -22,12 +22,12 @@ You can do that in MLX directly:
This function performs that operation while leaving the implementation and
function transformations to MLX.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a CPU operation.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
@@ -55,7 +55,7 @@ C++:
* Scale and sum two vectors element-wise
* z = alpha * x + beta * y
*
* Follow numpy style broadcasting between x and y
* Use NumPy-style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
@@ -66,7 +66,7 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to this operation is in terms of existing operations:
The simplest way to implement this is with existing operations:
.. code-block:: C++
@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:
.. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
@@ -153,9 +153,6 @@ more concrete:
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -234,49 +231,57 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by implementing :meth:`Axpby::eval_cpu`.
Our naive method will go over each element of the output array, find the
The method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
template <typename T>
void axpby_impl(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
.. code-block:: C++
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
primitive here.
Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -466,7 +391,7 @@ below.
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;
@@ -544,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
@@ -556,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If, argnums = {0, 1}, we take contributions from both
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -810,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
Output:
@@ -818,13 +743,13 @@ Output:
c shape: [3, 4]
c dtype: float32
c correctness: True
c is correct: True
Results
^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
with the naive :meth:`simple_axpby` we first defined.
.. code-block:: python
@@ -832,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
from mlx_sample_extensions import axpby
import time
mx.set_default_device(mx.cpu)
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 256
N = 512
M = 4096
N = 4096
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
@@ -849,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
def bench(f):
# Warm up
for i in range(100):
for i in range(5):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(5000):
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return e - s
return 1000 * (e - s) / 100
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
modest improvements right away!
This operation is now good to be used to build other operations, in

View File

@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/memory_management
python/nn
python/optimizers
python/distributed

View File

@@ -20,5 +20,6 @@ Linear Algebra
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
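
A short sketch of how these helpers read from the top-level namespace after the move in #1982 (the array shape is arbitrary):

.. code:: python

    import mlx.core as mx

    a = mx.ones((4096, 4096))
    mx.eval(a)

    # These previously lived under mlx.core.metal; after #1982 they are
    # exposed directly on mlx.core.
    print(mx.get_active_memory())
    print(mx.get_peak_memory())
    mx.clear_cache()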

View File

@@ -8,13 +8,5 @@ Metal
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -174,6 +174,7 @@ In detail:
value_and_grad
quantize
average_gradients
.. toctree::

View File

@@ -36,10 +36,12 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve

View File

@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile

View File

@@ -5,21 +5,27 @@ Distributed Communication
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
A lot of operations may not be supported or not as fast as they should be.
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
A distributed program in MLX is as simple as:
.. code:: python
@@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as:
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: python
import mlx.core as mx
x = ...
world = mx.distributed.init()
# No need for the check we can simply do x = mx.distributed.all_sum(x)
if world.size() > 1:
x = mx.distributed.all_sum(x)
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
$ mlx.launch -n 4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)
.. code:: shell
$ conda install conda-forge::openmpi
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.
.. code:: shell
Selecting Backend
^^^^^^^^^^^^^^^^^
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
return **the same backend** if called without arguments or with backend set to
``any``.
The following examples aim to clarify the backend initialization logic in MLX:
.. code:: python
# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
world = mx.distributed.init(backend="mpi")
world2 = mx.distributed.init()  # subsequent calls return the MPI backend!
# Case 2: Initialize any backend
world = mx.distributed.init(backend="any")  # equivalent to no arguments
world2 = mx.distributed.init()  # same as above
# Case 3: Initialize both backends at the same time
world_mpi = mx.distributed.init(backend="mpi")
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init()  # same as MPI because it was initialized first!
Training Example
----------------
@@ -155,13 +166,179 @@ everything else remaining the same.
optimizer.update(model, grads)
return loss
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
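Conceptually, averaging a single gradient across the group is an ``all_sum``
followed by a division by the group size. The sketch below illustrates the idea
for one stand-in gradient; the actual :func:`mlx.nn.average_gradients` also
aggregates many gradients into fewer, larger communications.
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
g = mx.ones((8,))  # stand-in for one gradient
# What averaging this gradient across the group amounts to:
g_avg = mx.distributed.all_sum(g) / world.size()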
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
The simplest possible usage is the following which, assuming the minimal
example at the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
.. note::
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs of these hosts.
.. code::
host1 slots=1
host2 slots=1
When using MLX, you will very likely want 1 slot per host, i.e. one process
per host. The hostfile also needs to contain the current host if you want to
run on the local host. Passing the host file to ``mpirun`` is simply done
using the ``--hostfile`` command line argument.
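For the passwordless-ssh requirement above, a conventional ``~/.ssh/config``
entry is usually sufficient; the host alias, user, and key path below are
illustrative:
.. code::
Host host1
HostName host1.example.com
User trainer
IdentityFile ~/.ssh/id_ed25519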
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
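For instance, assuming the fastest interface is ``en0`` and that both options
can be combined in a single ``--mpi-arg`` string, a launch command might look
like the following sketch:
.. code:: shell
$ mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_links 4 --mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py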
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests, the nodes are connected in a ring, which means that rank 1
can only communicate with ranks 0 and 2, rank 2 only with ranks 1 and 3, and so
on. As a result, :func:`send` and :func:`recv` with arbitrary senders and
receivers are not supported in the ring backend.
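A minimal sketch of neighbor-only communication, assuming the ``send(x, dst)``
and ``recv_like(x, src)`` signatures from :mod:`mlx.core.distributed`:
.. code:: python
import mlx.core as mx
world = mx.distributed.init(backend="ring")
rank, size = world.rank(), world.size()
left, right = (rank - 1) % size, (rank + 1) % size
x = mx.ones((4,))
s = mx.distributed.send(x, right)  # allowed: right is a neighbor
r = mx.distributed.recv_like(x, left)  # allowed: left is a neighbor
mx.eval(s, r)  # a non-neighbor rank would not be reachable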
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen on for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node and run the script, which will listen for connections on each of the
provided IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and
accept a connection from ``123.123.123.4``, and so on.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but it is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config``, your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via Thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
To validate your connection without configuring anything,
``mlx.distributed_config`` can also plot the ring using the DOT format.
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows (a
command sketch follows the list):
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
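A sketch of the manual commands on macOS, assuming the cable shows up as
``en2`` on both nodes and that ``bridge0`` is the Thunderbolt bridge; interface
names vary from machine to machine:
.. code:: shell
# On node i
sudo ifconfig bridge0 down
sudo ifconfig en2 inet 192.168.0.1 netmask 255.255.255.0
# On node i + 1
sudo ifconfig bridge0 down
sudo ifconfig en2 inet 192.168.0.2 netmask 255.255.255.0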
View File
@@ -0,0 +1,105 @@
:orphan:
.. _usage_launch_distributed:
Launching Distributed Programs
==============================
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Usage
-----
The minimal usage example of ``mlx.launch`` is simply
.. code:: shell
mlx.launch --hosts ip1,ip2 my_script.py
or for testing on localhost
.. code:: shell
mlx.launch -n 2 my_script.py
The ``mlx.launch`` command connects to the provided hosts and launches the input
script on each of them. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the stdout and stderr of each remote process
to the local stdout and stderr respectively.
Providing Hosts
^^^^^^^^^^^^^^^^
Hosts can be provided as command line arguments, like above, but the way that
allows you to fully define a list of hosts is via a JSON hostfile. The hostfile
has a very simple schema. It is simply a list of objects that define each host
via a hostname to ssh to and a list of IPs to utilize for the communication.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
{"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
]
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.
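Assuming the hostfile above is saved as ``hosts.json`` (the name is arbitrary),
it is passed to the launcher like so:
.. code:: shell
mlx.launch --hostfile hosts.json my_script.py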
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^^
In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover, the input script and the python binary need to exist
on each host at the same path. A good checklist to debug errors is the following:
* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
.. _mpi_specifics:
MPI Specifics
-------------
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
to connect to every other node
* ``mpirun`` needs to be available on every node at the same path
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different from MPI (a combined example
follows the list):
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.
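Putting these together, a ring launch might look like the following sketch,
where the port, the connection count, and the hostfile are illustrative:
.. code:: shell
mlx.launch --backend ring --starting-port 5000 --connections-per-ip 2 --hostfile hosts.json my_script.py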
View File
@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
@@ -21,6 +20,12 @@ execute_process(
OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
# Add library
View File
@@ -1,20 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023-2025 Apple Inc.
#include <cassert>
#include <iostream>
#include <sstream>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h"
#include "axpby/axpby.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#endif
#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
@@ -76,136 +70,65 @@ void axpby_impl(
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(y);
encoder.set_output_array(out);
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Launch the CPU kernel
encoder.dispatch([x_ptr = x.data<T>(),
y_ptr = y.data<T>(),
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
// Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
});
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == mx::float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
} else if (out.dtype() == mx::complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef ACCELERATE_NEW_LAPACK
template <typename T>
void axpby_impl_accelerate(
const mx::array& x,
const mx::array& y,
mx::array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, mx::CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == mx::float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
eval(inputs, outputs);
}
#endif
///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation
///////////////////////////////////////////////////////////////////////////////
@@ -217,7 +140,6 @@ void Axpby::eval_gpu(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
@@ -236,12 +158,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
out.set_data(mx::allocator::malloc(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
View File
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2025 Apple Inc.
#pragma once
@@ -85,11 +85,6 @@ class Axpby : public mx::Primitive {
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(
const std::vector<mx::array>& inputs,
std::vector<mx::array>& outputs);
};
} // namespace my_ext
View File
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2025 Apple Inc.
#include <metal_stdlib>
View File
@@ -19,6 +19,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
View File
@@ -4,12 +4,11 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true);
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator
View File
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator
View File
@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
return outputs;
}
array array::unsafe_weak_copy(const array& other) {
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
cpy.set_data(
other.buffer(),
other.data_size(),
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
Shape{static_cast<ShapeElem>(data.size())},
@@ -76,35 +88,27 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
set_data(data, deleter);
}
array::array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, data_size, std::move(strides), flags, deleter);
}
void array::detach() {
array_desc_->primitive = nullptr;
for (auto& s : array_desc_->siblings) {
s.array_desc_->primitive = nullptr;
}
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->primitive = nullptr;
}
bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
} else if (
status() == Status::evaluated &&
(!event().valid() || event().is_signaled())) {
set_status(Status::available);
return true;
}
@@ -113,7 +117,10 @@ bool array::is_available() const {
void array::wait() {
if (!is_available()) {
event().wait();
if (event().valid()) {
event().wait();
detach_event();
}
set_status(Status::available);
}
}
@@ -174,34 +181,13 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(
array other,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::~array() {
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
// Detached/detaching
if (array_desc_->primitive == nullptr) {
return;
}
View File
@@ -199,6 +199,13 @@ class array {
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
/**
* Get a new array that refers to the same data as the input but with a
* non-owning pointer to it. Note the array is detached from the graph and has
* no inputs, siblings or primitive.
*/
static array unsafe_weak_copy(const array& other);
/** A unique identifier for an array. */
std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -243,18 +250,6 @@ class array {
bool col_contiguous : 1;
};
/** Build an array from all the info held by the array description. Including
* the buffer, strides, flags.
*/
explicit array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter = allocator::free);
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
@@ -365,11 +360,6 @@ class array {
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
// The output of a computation which has been scheduled but `eval_*` has
// not yet been called on the array's primitive. A possible
// status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
@@ -406,6 +396,10 @@ class array {
array_desc_->event = std::move(e);
}
void detach_event() const {
array_desc_->event = Event{};
}
// Mark the array as a tracer array (true) or not.
void set_tracer(bool is_tracer) {
array_desc_->is_tracer = is_tracer;
@@ -431,15 +425,6 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
View File
@@ -38,25 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
BinaryOpType bopt) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -64,14 +59,10 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::VectorScalar:
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -79,20 +70,12 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::VectorVector:
if (a_donatable) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
out.copy_shared_buffer(a);
} else if (b_donatable) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -100,20 +83,12 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::General:
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
out.copy_shared_buffer(a);
} else if (
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}
View File
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway.
size_t data_size = out.size();
return move_or_copy(in, out, strides_, flags, data_size, offset_);
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
@@ -56,7 +56,7 @@ void broadcast(const array& in, array& out) {
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
move_or_copy(in, out, strides, flags, in.data_size());
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -69,7 +69,7 @@ void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void CustomTransforms::eval(
@@ -78,7 +78,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
move_or_copy(inputs[j], outputs[i]);
outputs[i].copy_shared_buffer(inputs[j]);
}
}
@@ -87,7 +87,7 @@ void Depends::eval(
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
move_or_copy(inputs[i], outputs[i]);
outputs[i].copy_shared_buffer(inputs[i]);
}
}
@@ -98,12 +98,12 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
for (auto ax : axes_) {
strides.insert(strides.begin() + ax, 1);
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
@@ -210,7 +210,7 @@ void shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
@@ -276,12 +276,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
strides.push_back(in.strides(i));
}
}
move_or_copy(in, out, strides, in.flags(), in.data_size());
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -315,7 +315,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri);
}
}
move_or_copy(in, out, out_strides, flags, in.data_size());
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core
View File
@@ -161,8 +161,7 @@ void compiled_allocate_outputs(
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
bool contiguous) {
if (contiguous) {
int o = 0;
Strides strides;
@@ -178,11 +177,7 @@ void compiled_allocate_outputs(
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
@@ -193,7 +188,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
allocator::malloc(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -210,18 +205,13 @@ void compiled_allocate_outputs(
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
}
}
}
View File
@@ -62,7 +62,6 @@ void compiled_allocate_outputs(
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
bool contiguous);
} // namespace mlx::core
View File
@@ -22,4 +22,25 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
return false;
}
}
} // namespace mlx::core
View File
@@ -3,7 +3,8 @@
#include <algorithm>
#include <utility>
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace {
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
break;
case 4:
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
break;
case 8:
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
break;
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
itemsize = out.itemsize(),
offset = offset_,
reader = reader_,
swap_endianness_ = swap_endianness_]() mutable {
reader->read(out_ptr, size * itemsize, offset);
if (swap_endianness_) {
switch (itemsize) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
}
};
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
}
} // namespace mlx::core
View File
@@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianess);
} // namespace mlx::core
View File
@@ -14,6 +14,10 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}
@@ -32,7 +36,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void slice(
@@ -54,9 +58,10 @@ void slice(
data_end += end_idx * in.strides()[i];
}
}
// data_end can be -1
size_t data_size =
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
if (data_end < 0) {
data_end += in.data_size();
}
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
View File
@@ -36,15 +36,10 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
TernaryOpType topt) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
out.copy_shared_buffer(x);
}
out.copy_shared_buffer(x);
return true;
}
return false;
@@ -53,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -69,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
}
break;
}
View File
@@ -4,28 +4,6 @@
namespace mlx::core {
void move_or_copy(const array& in, array& out) {
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.copy_shared_buffer(in);
}
}
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
View File
@@ -159,15 +159,6 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra;
}
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
void shared_buffer_reshape(
View File
@@ -44,7 +44,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -56,6 +58,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -65,13 +68,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif()
if(IOS)
View File
@@ -2,76 +2,27 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size) {
void arange(T start, T next, array& out, size_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([ptr, start, step_size, size]() mutable {
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
});
}
} // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core
View File
@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -17,15 +18,18 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc;
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v);
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out.data<uint32_t>()[i] = ind_v;
out_ptr[i] = ind_v;
}
}
@@ -64,52 +68,59 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
});
}
} // namespace mlx::core
View File
@@ -8,6 +8,7 @@
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -16,51 +17,218 @@ namespace mlx::core {
namespace {
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
binary_op<int8_t, bool>(a, b, out, op);
break;
case int16:
binary_op<int16_t, bool>(a, b, out, op);
break;
case int32:
binary_op<int32_t, bool>(a, b, out, op);
break;
case int64:
binary_op<int64_t, bool>(a, b, out, op);
break;
case float16:
binary_op<float16_t, bool>(a, b, out, op);
break;
case float32:
binary_op<float, bool>(a, b, out, op);
break;
case float64:
binary_op<double, bool>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, op);
break;
}
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
@@ -69,7 +237,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
binary(a, b, out, detail::Add(), stream());
}
void DivMod::eval_cpu(
@@ -78,70 +246,89 @@ void DivMod::eval_cpu(
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out_a = array::unsafe_weak_copy(out_a),
out_b = array::unsafe_weak_copy(out_b),
bopt]() mutable {
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (out_a.dtype()) {
case bool_:
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint8:
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint16:
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint32:
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint64:
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int8:
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int16:
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int32:
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int64:
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case float16:
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case float32:
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
break;
case float64:
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
});
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
binary(a, b, out, detail::Divide(), stream());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
binary(a, b, out, detail::Remainder(), stream());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -149,181 +336,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a = inputs[0];
auto& b = inputs[1];
if (equal_nan_) {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
break;
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
break;
case float64:
binary_op<double, bool>(a, b, out, detail::NaNEqual());
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
});
} else {
comparison_op(a, b, out, detail::Equal());
comparison_op(a, b, out, detail::Equal(), stream());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::LogAddExp());
break;
case float32:
binary_op<float>(a, b, out, detail::LogAddExp());
break;
case float64:
binary_op<double>(a, b, out, detail::LogAddExp());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
binary_float(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
binary(in1, in2, out, detail::LogicalAnd(), stream());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
binary(in1, in2, out, detail::LogicalOr(), stream());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
binary(a, b, out, detail::Maximum(), stream());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
binary(a, b, out, detail::Minimum(), stream());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
binary(a, b, out, detail::Multiply(), stream());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
binary(a, b, out, detail::Power(), stream());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
binary(a, b, out, detail::Subtract(), stream());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
binary_int(a, b, out, detail::BitwiseAnd(), stream());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
binary_int(a, b, out, detail::BitwiseOr(), stream());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
binary_int(a, b, out, detail::BitwiseXor(), stream());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
binary_int(a, b, out, detail::LeftShift(), stream());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
binary_int(a, b, out, detail::RightShift(), stream());
break;
}
}
@@ -332,23 +481,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
binary_float(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core
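For orientation, every primitive above now follows the same three-step shape: configure the output buffer synchronously, register inputs and outputs with the stream's command encoder, then run the typed kernel inside a dispatched task. A minimal, hypothetical sketch of that pattern for a unary op (negate_f32 is illustrative, not part of this change):
// Illustrative only: the capture/dispatch pattern used by the primitives
// above. The task captures weak copies of the arrays, mirroring the
// DivMod and Equal rewrites shown earlier.
void negate_f32(const array& in, array& out, Stream s) {
  auto& encoder = cpu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  encoder.dispatch([in = array::unsafe_weak_copy(in),
                    out = array::unsafe_weak_copy(out)]() mutable {
    const float* x = in.data<float>();
    float* y = out.data<float>();
    for (size_t i = 0; i < out.size(); ++i) {
      y[i] = -x[i];
    }
  });
}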


@@ -3,7 +3,6 @@
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
@@ -14,22 +13,18 @@ namespace mlx::core {
template <typename Op>
struct VectorScalar {
Op op;
VectorScalar(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
dst += N;
a += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, scalar);
*dst = Op{}(*a, scalar);
dst++;
a++;
}
@@ -38,22 +33,18 @@ struct VectorScalar {
template <typename Op>
struct ScalarVector {
Op op;
ScalarVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
dst += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(scalar, *b);
*dst = Op{}(scalar, *b);
dst++;
b++;
}
@@ -62,22 +53,18 @@ struct ScalarVector {
template <typename Op>
struct VectorVector {
Op op;
VectorVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
dst += N;
a += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, *b);
*dst = Op{}(*a, *b);
dst++;
a++;
b++;
@@ -90,7 +77,6 @@ void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
@@ -104,12 +90,12 @@ void binary_op_dims(
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
Op{}(a, b, out, stride_out);
} else {
*out = op(*a, *b);
*out = Op{}(*a, *b);
}
}
out += stride_out;
@@ -120,66 +106,38 @@ void binary_op_dims(
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
const T* a,
const T* b,
U* out,
int dim,
int size,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
a, b, out, shape, a_strides, b_strides, out_strides, 0);
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
for (int64_t elem = 0; elem < size; elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
a + a_it.loc,
b + b_it.loc,
out + elem,
shape,
a_strides,
b_strides,
@@ -191,40 +149,41 @@ void binary_op_dispatch_dims(
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
// The full computation is scalar scalar so call the base op once
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_ptr = out.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
*out_ptr = Op{}(*a_ptr, *b_ptr);
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
auto& a_strides = new_strides[0];
auto& b_strides = new_strides[1];
auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
@@ -248,7 +207,8 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
// Case 1: LxM and FxM where L and F are broadcastable and M is row
// contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
@@ -275,99 +235,59 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorVector{op},
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorScalar{op},
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
ScalarVector{op},
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
binary_op_dispatch_dims<T, U, false, Op>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T>(a, b, out, op);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
case float16:
binary_op<float16_t>(a, b, out, op);
break;
case float32:
binary_op<float>(a, b, out, op);
break;
case float64:
binary_op<double>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t>(a, b, out, op);
break;
}
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
}
} // namespace mlx::core
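One consequence of the refactor above: the op is now a template parameter constructed as Op{} at the point of use, so anything plugged into binary_op must be a stateless, default-constructible functor, and it must also be callable on simd::Simd operands for the vectorized paths. A hypothetical conforming op (not part of this change):
// Works for both scalar T and simd::Simd<T, N> arguments, since both
// support operator- and operator*.
struct SquaredDiff {
  template <typename T>
  T operator()(T a, T b) const {
    return (a - b) * (a - b);
  }
};
// Usage sketch: binary_op<float, float, SquaredDiff>(a, b, out, bopt);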


@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
int ndim = shape.size();
switch (ndim) {
case 1:
@@ -120,14 +120,10 @@ template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
array& out_a,
array& out_b,
Op op,
BinaryOpType bopt) {
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
@@ -141,14 +137,14 @@ void binary_op(
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) {
for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) {
for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
@@ -165,58 +161,6 @@ void binary_op(
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}
} // namespace
} // namespace mlx::core
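The switch from size() to data_size() in the scalar-vector loops above matters whenever an array's buffer is smaller than its logical shape. A hedged illustration using the public broadcast_to op:
// A scalar broadcast to 1024 elements still stores exactly one value:
array b = broadcast_to(array(1.0f), {1024});
// b.size()      == 1024  (logical element count)
// b.data_size() == 1     (elements actually present in the buffer)
Iterating data_size() elements touches each stored value exactly once, which is also what set_binary_op_output_data sizes the outputs for.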


@@ -2,13 +2,15 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
void cholesky_impl(const array& a, array& factor, bool upper) {
template <typename T>
void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
@@ -16,59 +18,68 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(factor);
encoder.dispatch([matrix = factor.data<T>(),
upper,
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
matrix += N;
}
}
});
}
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
switch (inputs[0].dtype()) {
case float32:
cholesky_impl<float>(inputs[0], output, upper_, stream());
break;
case float64:
cholesky_impl<double>(inputs[0], output, upper_, stream());
break;
default:
throw std::runtime_error(
"[Cholesky::eval_cpu] only supports float32 or float64.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core
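The potrf<T> call above goes through a small templated shim over the LAPACK entry points (it lives in mlx/backend/cpu/lapack.h, included at the top of this file). A hedged sketch of what such a shim looks like for float; the in-tree form may differ:
extern "C" void spotrf_(char* uplo, int* n, float* a, int* lda, int* info);

template <typename T>
void potrf(char* uplo, int* n, T* a, int* lda, int* info);

template <>
inline void potrf<float>(char* uplo, int* n, float* a, int* lda, int* info) {
  spotrf_(uplo, n, a, lda, info); // float Cholesky; double forwards to dpotrf_
}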


@@ -11,6 +11,7 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
@@ -288,6 +289,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -298,6 +300,7 @@ void Compiled::eval_cpu(
continue;
}
auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
@@ -356,18 +359,25 @@ void Compiled::eval_cpu(
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) {
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
Shape out_shape;
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
encoder.dispatch(
[fun,
args = std::move(args),
strides = std::move(strides),
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core
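The reason out_shape (and strides) move into the capture list above is lifetime: args holds raw pointers into those containers, and the jitted function now runs asynchronously. A hedged distillation of the idiom:
// args stores out_shape.data(); moving the vector into the lambda keeps
// that pointer valid (moving a std::vector preserves element addresses)
// until the dispatched task has executed.
Shape out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
encoder.dispatch(
    [fun, args = std::move(args), out_shape = std::move(out_shape)]() mutable {
      fun(args.data());
    });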

(diff omitted: file too large)


@@ -5,6 +5,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -13,19 +14,19 @@ namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
for (int i = 0; i < dst.size(); ++i) {
dst_ptr[i] = val;
}
auto size = dst.size();
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size();
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
auto size = src.data_size();
std::copy(src_ptr, src_ptr + size, dst_ptr);
}
template <typename SrcT, typename DstT, int D>
@@ -60,36 +61,57 @@ void copy_general_general(
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset) {
int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
auto i_offset_ptr =
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
if (data_shape.empty()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
auto dst_ptr = dst.data<DstT>() + o_offset;
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
if (ndim < 3) {
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return;
}
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
@@ -105,7 +127,15 @@ void copy_general_general(
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
src,
dst,
src.shape(),
src.strides(),
dst.strides(),
0,
0,
std::nullopt,
std::nullopt);
}
template <typename SrcT, typename DstT>
@@ -116,7 +146,9 @@ void copy_general(
const Strides& i_strides,
const Strides&,
int64_t i_offset,
int64_t o_offset) {
int64_t o_offset,
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
copy_general_general<SrcT, DstT>(
src,
dst,
@@ -124,7 +156,9 @@ void copy_general(
i_strides,
make_contiguous_strides(data_shape),
i_offset,
o_offset);
o_offset,
dynamic_i_offset,
dynamic_o_offset);
}
template <typename SrcT, typename DstT>
@@ -136,7 +170,9 @@ inline void copy_general(const array& src, array& dst) {
src.strides(),
make_contiguous_strides(src.shape()),
0,
0);
0,
std::nullopt,
std::nullopt);
}
template <typename SrcT, typename DstT, typename... Args>
@@ -259,35 +295,27 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
copy_inplace_dispatch(src, dst, ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
}
void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break;
case CopyType::Scalar:
case CopyType::General:
case CopyType::GeneralGeneral:
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
break;
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
bool donated = set_copy_output_data(src, dst, ctype);
if (donated && src.dtype() == dst.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_inplace(src, dst, ctype);
copy_inplace(src, dst, ctype, stream);
}
void copy_inplace(
@@ -298,24 +326,51 @@ void copy_inplace(
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
if (x) {
return array::unsafe_weak_copy(*x);
} else {
return std::nullopt;
}
};
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
ctype,
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
dynamic_i_offset,
dynamic_o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
});
}
} // namespace mlx::core
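The new optional dynamic_i_offset/dynamic_o_offset parameters let an offset be produced by an earlier task on the same stream and read only inside the dispatched lambda, after that producer has run. A hedged usage sketch, where idx is a hypothetical one-element int64 array:
copy_inplace(
    src,
    dst,
    src.shape(),
    src.strides(),
    dst.strides(),
    /* i_offset = */ 0,
    /* o_offset = */ 0,
    CopyType::GeneralGeneral,
    stream,
    /* dynamic_i_offset = */ idx, // resolved when the task runs
    /* dynamic_o_offset = */ std::nullopt);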


@@ -2,14 +2,16 @@
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
void copy_inplace(
const array& src,
@@ -19,6 +21,9 @@ void copy_inplace(
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
CopyType ctype,
Stream stream,
const std::optional<array>& dynamic_i_offset = std::nullopt,
const std::optional<array>& dynamic_o_offset = std::nullopt);
} // namespace mlx::core


@@ -0,0 +1,94 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/primitives.h"
namespace mlx::core::distributed {
std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
if (arr.flags().row_contiguous) {
return {arr, false};
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General, stream);
return {arr_copy, true};
}
}
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto donate_or_copy = [s = stream()](const array& in, array& out) {
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
return in;
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, s);
out.copy_shared_buffer(arr_copy);
return arr_copy;
}
};
auto in = donate_or_copy(inputs[0], outputs[0]);
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
distributed::detail::send(group(), in, dst_, stream());
outputs[0].copy_shared_buffer(inputs[0]);
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporary(in);
}
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
} // namespace mlx::core::distributed


@@ -3,6 +3,7 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -11,35 +12,76 @@ namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
template <typename T>
void eigh_impl(
array& vectors,
array& values,
const std::string& uplo,
bool compute_eigenvectors,
Stream stream) {
auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>();
char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(vectors);
encoder.set_output_array(values);
encoder.dispatch([vec_ptr,
eig_ptr,
jobz,
uplo = uplo[0],
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
int lwork = -1;
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
if (!compute_eigenvectors) {
encoder.add_temporary(vectors);
}
}
@@ -55,12 +97,13 @@ void Eigh::eval_cpu(
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
values.set_data(allocator::malloc(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
@@ -78,41 +121,19 @@ void Eigh::eval_cpu(
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
switch (a.dtype()) {
case float32:
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
break;
case float64:
eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64.");
}
}
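The lwork = -1 block above is the standard LAPACK workspace-query convention: called with lwork == -1, syevd performs no decomposition and instead reports optimal workspace sizes through work and iwork, which are then allocated once and reused for every matrix in the batch. Distilled (hedged, float case):
int lwork = -1, liwork = -1, info;
float wkopt;
int iwkopt;
syevd<float>(&jobz, &uplo, &N, nullptr, &N, nullptr,
             &wkopt, &lwork, &iwkopt, &liwork, &info); // query only
lwork = static_cast<int>(wkopt);   // optimal float workspace size
liwork = iwkopt;                   // optimal integer workspace size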


@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core::cpu {
CommandEncoder& get_command_encoder(Stream stream) {
static std::unordered_map<int, CommandEncoder> encoder_map;
auto it = encoder_map.find(stream.index);
if (it == encoder_map.end()) {
it = encoder_map.emplace(stream.index, stream).first;
}
return it->second;
}
} // namespace mlx::core::cpu

mlx/backend/cpu/encoder.h (new file, 67 lines)

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/scheduler.h"
namespace mlx::core::cpu {
// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;
struct CommandEncoder {
CommandEncoder(Stream stream) : stream_(stream) {}
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
CommandEncoder(CommandEncoder&&) = delete;
CommandEncoder& operator=(CommandEncoder&&) = delete;
void set_input_array(const array& a) {}
void set_output_array(array& a) {}
// Hold onto a temporary until any already scheduled tasks which use it as
// an input are complete.
void add_temporary(array arr) {
temporaries_.push_back(std::move(arr));
}
void add_temporaries(std::vector<array> arrays) {
temporaries_.insert(
temporaries_.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
std::vector<array>& temporaries() {
return temporaries_;
}
template <class F, class... Args>
void dispatch(F&& f, Args&&... args) {
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
if (num_ops_ == 0) {
scheduler::notify_new_task(stream_);
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
task();
scheduler::notify_task_completion(s);
};
scheduler::enqueue(stream_, std::move(task_wrap));
} else {
scheduler::enqueue(stream_, std::move(task));
}
}
private:
Stream stream_;
std::vector<array> temporaries_;
int num_ops_{0};
};
CommandEncoder& get_command_encoder(Stream stream);
} // namespace mlx::core::cpu
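A hedged usage sketch of the encoder: work is enqueued per stream, every DISPATCHES_PER_TASK-th dispatch additionally notifies the scheduler so evaluation can track completion, and arrays registered with add_temporary stay alive until already-scheduled tasks are done with them.
auto& enc = cpu::get_command_encoder(stream);
enc.set_input_array(in);
enc.set_output_array(out);
enc.dispatch([x = in.data<float>(), y = out.data<float>(), n = out.size()]() {
  for (size_t i = 0; i < n; ++i) {
    y[i] = 2 * x[i]; // runs later, on the stream's worker thread
  }
});
enc.add_temporary(scratch); // scratch is a hypothetical workspace array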

mlx/backend/cpu/eval.cpp (new file, 34 lines)

@@ -0,0 +1,34 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/eval.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::cpu {
void eval(array& arr) {
auto s = arr.primitive().stream();
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_cpu(arr.inputs(), outputs);
}
}
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) {
auto& encoder = cpu::get_command_encoder(s);
encoder.dispatch([s,
buffers = std::move(retain_buffers),
temps = std::move(encoder.temporaries())]() {});
}
} // namespace mlx::core::cpu

mlx/backend/cpu/eval.h (new file, 17 lines)

@@ -0,0 +1,17 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::cpu {
void eval(array& arr);
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers);
} // namespace mlx::core::cpu


@@ -4,6 +4,7 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -21,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {
@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
});
scale /= nelem;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
encoder.dispatch([shape = std::move(shape),
strides_in = std::move(strides_in),
strides_out = std::move(strides_out),
axes = axes_,
inverse = inverse_,
in_ptr,
out_ptr,
scale]() {
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes,
!inverse,
in_ptr,
out_ptr,
scale);
});
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");


@@ -7,14 +7,20 @@ namespace mlx::core {
template <typename T>
void matmul(
const array& a,
const array& b,
array& out,
const T* a,
const T* b,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides);
} // namespace mlx::core
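A hedged sketch of a call through the new pointer-based interface, for a non-transposed, row-major batched product with alpha = 1 and beta = 0. Shapes and strides are passed in so each backend can resolve per-batch offsets itself via elem_to_loc:
matmul<float>(
    a_ptr, b_ptr, out_ptr,
    /* a_transposed = */ false,
    /* b_transposed = */ false,
    /* lda = */ K,
    /* ldb = */ N,
    /* ldc = */ N,
    /* alpha = */ 1.0f,
    /* beta = */ 0.0f,
    /* batch_size = */ batch,
    a_shape, a_strides,
    b_shape, b_strides);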


@@ -9,40 +9,49 @@
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
template <typename T>
constexpr BNNSDataType to_bnns_dtype();
template <>
constexpr BNNSDataType to_bnns_dtype<float>() {
return BNNSDataType(BNNSDataTypeFloatBit | 32);
}
template <>
constexpr BNNSDataType to_bnns_dtype<float16_t>() {
return BNNSDataType(BNNSDataTypeFloatBit | 16);
}
template <>
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
return BNNSDataTypeBFloat16;
}
template <typename T>
void matmul_bnns(
const array& a,
const array& b,
array& out,
const T* a,
const T* b,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
@@ -113,45 +122,88 @@ void matmul_bnns(
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
reinterpret_cast<const uint8_t*>(
a + elem_to_loc(M * K * i, a_shape, a_strides)),
reinterpret_cast<const uint8_t*>(
b + elem_to_loc(K * N * i, b_shape, b_strides)),
reinterpret_cast<uint8_t*>(out + M * N * i));
}
BNNSFilterDestroy(bnns_filter);
#pragma GCC diagnostic pop
}
template <>
void matmul<float16_t>(
const array& a,
const array& b,
array& out,
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
}
template <>
void matmul<bfloat16_t>(
const array& a,
const array& b,
array& out,
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
matmul_bnns(
a,
b,
out,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
}
} // namespace mlx::core


@@ -8,20 +8,27 @@ namespace mlx::core {
template <>
void matmul<float>(
const array& a,
const array& b,
array& out,
const float* a,
const float* b,
float* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -29,34 +36,40 @@ void matmul<float>(
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}
template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
const double* a,
const double* b,
double* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
@@ -64,15 +77,14 @@ void matmul<double>(
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
alpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
beta,
out + M * N * i,
ldc);
}
}


@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core


@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core


@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core


@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core


@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core
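A hedged standalone use of the fallback kernel: one small half-precision GEMM with float accumulation, C = A·B on plain row-major buffers. With beta = 0 the kernel never reads C, so it may be left uninitialized:
float16_t A[4 * 4], B[4 * 4], C[4 * 4];
// ...fill A and B...
simd_gemm<float16_t, float>(
    A, B, C,
    /* a_trans = */ false,
    /* b_trans = */ false,
    /* M = */ 4, /* N = */ 4, /* K = */ 4,
    /* alpha = */ 1.0f, /* beta = */ 0.0f);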


@@ -4,16 +4,17 @@
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
for (int b = 0; b < size / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
@@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) {
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
@@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) {
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
for (int b = 0; b < size / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
T* data_ptr = out + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
@@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) {
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
void hadamard(array& out, int n, int m, float scale, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out_ptr, n, m, n_scale, size);
if (m > 1) {
hadamard_m<T>(out_ptr, n, m, scale, size);
}
});
}
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
if (in.flags().row_contiguous && in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
copy(
in,
out,
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
}
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
return hadamard<float>(out, n, m, scale_, stream());
case float16:
return hadamard<float16_t>(out, n, m, scale_);
return hadamard<float16_t>(out, n, m, scale_, stream());
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
return hadamard<bfloat16_t>(out, n, m, scale_, stream());
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
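For reference, the 2^k component that hadamard_n batches over is presumably the standard in-place fast Walsh-Hadamard butterfly, which matches the h-doubling loop visible in the hunk above. A self-contained sketch for a single batch (fwht is an illustrative name, not the MLX function):

#include <cstdio>

void fwht(float* x, int n, float scale) {
  // Butterfly passes: pair elements at distance h, replacing them with
  // their sum and difference, doubling h each pass.
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float a = x[j];
        float b = x[j + h];
        x[j] = a + b;
        x[j + h] = a - b;
      }
    }
  }
  for (int i = 0; i < n; ++i) {
    x[i] *= scale; // e.g. 1/sqrt(n) for an orthonormal transform
  }
}

int main() {
  float x[4] = {1, 0, 0, 0};
  fwht(x, 4, 0.5f);
  for (float v : x) {
    std::printf("%g ", v); // 0.5 0.5 0.5 0.5
  }
  return 0;
}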

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
@@ -21,6 +22,40 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
struct None {
template <typename T>
void operator()(T x, T* y) {
(*y) = x;
}
};
struct Sum {
template <typename T>
void operator()(T x, T* y) {
(*y) += x;
}
};
struct Prod {
template <typename T>
void operator()(T x, T* y) {
(*y) *= x;
}
};
struct Max {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y > x) ? *y : x;
}
};
struct Min {
template <typename T>
void operator()(T x, T* y) {
(*y) = (*y < x) ? *y : x;
}
};
template <typename T, typename IdxT>
void gather(
const array& src,
@@ -73,13 +108,14 @@ void gather(
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
@@ -161,46 +197,59 @@ void dispatch_gather(
}
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
inds.push_back(array::unsafe_weak_copy(*it));
}
auto& encoder = cpu::get_command_encoder(stream());
for (auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
slice_sizes_ = slice_sizes_,
src = array::unsafe_weak_copy(src),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename T, typename IdxT>
void gather_axis(
@@ -235,6 +284,7 @@ void gather_axis(
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
@@ -304,39 +354,49 @@ void dispatch_gather_axis(
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(src);
encoder.set_input_array(inds);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
src = array::unsafe_weak_copy(src),
inds = array::unsafe_weak_copy(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
break;
}
});
}
template <typename InT, typename IdxT, typename OpT>
@@ -344,8 +404,7 @@ void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op) {
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
@@ -361,9 +420,11 @@ void scatter(
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
@@ -373,8 +434,7 @@ void scatter(
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
@@ -392,26 +452,19 @@ void dispatch_scatter_inds(
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
scatter<InT, IdxT, None>(updates, out, indices, axes);
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
scatter<InT, IdxT, Sum>(updates, out, indices, axes);
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
scatter<InT, IdxT, Prod>(updates, out, indices, axes);
break;
case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y > x) ? *y : x;
});
scatter<InT, IdxT, Max>(updates, out, indices, axes);
break;
case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y < x) ? *y : x;
});
scatter<InT, IdxT, Min>(updates, out, indices, axes);
break;
}
}
@@ -463,67 +516,75 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());
switch (src.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
auto& encoder = cpu::get_command_encoder(stream());
std::vector<array> inds;
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
encoder.set_input_array(*it);
inds.push_back(array::unsafe_weak_copy(*it));
}
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axes_ = axes_,
reduce_type_ = reduce_type_,
updates = array::unsafe_weak_copy(updates),
inds = std::move(inds),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
});
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
@@ -557,8 +618,9 @@ void scatter_axis(
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
OpT{}(
upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
@@ -576,12 +638,10 @@ void dispatch_scatter_axis_op(
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
break;
}
}
@@ -634,53 +694,65 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, ctype, stream());
switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(idx);
encoder.set_input_array(updates);
encoder.set_output_array(out);
encoder.dispatch([axis_ = axis_,
reduce_type_ = reduce_type_,
idx = array::unsafe_weak_copy(idx),
updates = array::unsafe_weak_copy(updates),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(
out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
});
}
} // namespace mlx::core
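The None/Sum/Prod/Max/Min functors above replace the per-call-site lambdas so the reduce op becomes a template parameter, instantiated as OpT{} where it is applied; since the functors are stateless, nothing needs to be captured or passed by reference. A toy scatter-add showing the same pattern (Assign/Add/scatter_1d are illustrative names):

#include <cstdio>
#include <vector>

struct Assign {
  template <typename T>
  void operator()(T x, T* y) {
    *y = x;
  }
};

struct Add {
  template <typename T>
  void operator()(T x, T* y) {
    *y += x;
  }
};

template <typename T, typename OpT>
void scatter_1d(T* out, const int* idx, const T* upd, int n_updates) {
  for (int i = 0; i < n_updates; ++i) {
    OpT{}(upd[i], out + idx[i]); // repeated indices reduce with OpT
  }
}

int main() {
  std::vector<float> out(4, 0.0f);
  int idx[3] = {1, 3, 1};
  float upd[3] = {10, 20, 30};
  scatter_1d<float, Add>(out.data(), idx, upd, 3);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // 0 40 0 20
  return 0;
}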

View File

@@ -2,47 +2,37 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
return info;
}
namespace mlx::core {
void general_inv(array& inv, int N, int i) {
template <typename T>
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
getrf<T>(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
ss << "[Inverse::eval_cpu] LU factorization failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
T workspace_size = 0;
// Compute workspace size.
sgetri_(
getri<T>(
/* n = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
@@ -53,68 +43,118 @@ void general_inv(array& inv, int N, int i) {
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
ss << "[Inverse::eval_cpu] LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Compute inverse.
sgetri_(
getri<T>(
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* a = */ inv,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
ss << "[Inverse::eval_cpu] inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
template <typename T>
void tri_inv(T* inv, int N, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
int info;
trtri<T>(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ inv,
/* lda = */ &N,
/* info = */ &info);
// zero out the other triangle
if (upper) {
for (int i = 0; i < N; i++) {
std::fill(inv, inv + i, 0.0f);
inv += N;
}
} else {
for (int i = 0; i < N; i++) {
std::fill(inv + i + 1, inv + N, 0.0f);
inv += N;
}
}
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
ss << "[Inverse::eval_cpu] triangular inversion failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
template <typename T>
void inverse_impl(
const array& a,
array& inv,
bool tri,
bool upper,
Stream stream) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(
a,
inv,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv(inv, N, i, upper);
} else {
general_inv(inv, N, i);
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(inv);
auto inv_ptr = inv.data<T>();
if (tri) {
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});
}
}
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
switch (inputs[0].dtype()) {
case float32:
inverse_impl<float>(inputs[0], output, tri_, upper_, stream());
break;
case float64:
inverse_impl<double>(inputs[0], output, tri_, upper_, stream());
break;
default:
throw std::runtime_error(
"[Inverse::eval_cpu] only supports float32 or float64.");
}
inverse_impl(inputs[0], output, tri_, upper_);
}
} // namespace mlx::core
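general_inv above follows the standard two-call LAPACK idiom: query the optimal work size by passing lwork = -1, allocate that much scratch, then run the real call. A standalone sketch against the conventional underscore-suffixed Fortran symbols (assumed here; adjust the declarations to your LAPACK binding):

#include <stdexcept>
#include <vector>

extern "C" {
void sgetrf_(int* m, int* n, float* a, int* lda, int* ipiv, int* info);
void sgetri_(
    int* n, float* a, int* lda, int* ipiv, float* work, int* lwork, int* info);
}

void invert_in_place(float* a, int n) {
  int info;
  std::vector<int> ipiv(n);
  sgetrf_(&n, &n, a, &n, ipiv.data(), &info); // LU factorization
  if (info != 0) {
    throw std::runtime_error("LU factorization failed");
  }
  int lwork = -1; // workspace query: optimal size is written to work[0]
  float work_size = 0;
  sgetri_(&n, a, &n, ipiv.data(), &work_size, &lwork, &info);
  lwork = static_cast<int>(work_size);
  std::vector<float> work(lwork);
  sgetri_(&n, a, &n, ipiv.data(), work.data(), &lwork, &info); // invert
  if (info != 0) {
    throw std::runtime_error("inversion failed");
  }
}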

View File

@@ -31,3 +31,22 @@
#define MLX_LAPACK_FUNC(f) f##_
#endif
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_TYPES(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr)
INSTANTIATE_LAPACK_TYPES(syevd)
INSTANTIATE_LAPACK_TYPES(potrf)
INSTANTIATE_LAPACK_TYPES(gesvdx)
INSTANTIATE_LAPACK_TYPES(getrf)
INSTANTIATE_LAPACK_TYPES(getri)
INSTANTIATE_LAPACK_TYPES(trtri)
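Each INSTANTIATE_LAPACK_TYPES(FUNC) expansion above generates a templated front-end whose float/double branch is resolved at compile time, so getrf<float>(...) forwards to sgetrf_ and getrf<double>(...) to dgetrf_ with no runtime dispatch. A self-contained analogue of the pattern with a toy backend pair (all names illustrative):

#include <cmath>
#include <cstdio>
#include <type_traits>

float toy_snorm(const float* x, int n) { // the "s"-prefixed backend
  float s = 0;
  for (int i = 0; i < n; ++i) s += x[i] * x[i];
  return std::sqrt(s);
}
double toy_dnorm(const double* x, int n) { // the "d"-prefixed backend
  double s = 0;
  for (int i = 0; i < n; ++i) s += x[i] * x[i];
  return std::sqrt(s);
}

template <typename T, typename... Args>
auto norm(Args... args) {
  // Only the matching branch is instantiated for each T.
  if constexpr (std::is_same_v<T, float>) {
    return toy_snorm(args...);
  } else if constexpr (std::is_same_v<T, double>) {
    return toy_dnorm(args...);
  }
}

int main() {
  float xf[2] = {3, 4};
  double xd[2] = {3, 4};
  std::printf("%g %g\n", norm<float>(xf, 2), norm<double>(xd, 2)); // 5 5
  return 0;
}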

View File

@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core
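The kernel above is a two-pass streaming logsumexp: one pass for the row maximum, one for the sum of shifted exponentials, with an isinf guard so a row of all -inf (or one containing +inf) returns the maximum instead of producing NaN from inf - inf. A scalar reference version of the same computation:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

float logsumexp_row(const float* x, int m) {
  // Pass 1: row maximum.
  float mx = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < m; ++i) {
    mx = std::max(mx, x[i]);
  }
  // Pass 2: normalizer over exponentials shifted by the maximum.
  float norm = 0.0f;
  for (int i = 0; i < m; ++i) {
    norm += std::exp(x[i] - mx);
  }
  return std::isinf(mx) ? mx : std::log(norm) + mx;
}

int main() {
  float x[3] = {1000.0f, 1000.0f, 1000.0f};
  // Naive log(sum(exp(x))) overflows; the shifted form gives 1000 + log(3).
  std::printf("%f\n", logsumexp_row(x, 3));
  return 0;
}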

View File

@@ -4,18 +4,22 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
void lu_factor_impl(
template <typename T>
void luf_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices) {
array& row_indices,
Stream stream) {
int M = a.shape(-2);
int N = a.shape(-1);
int K = std::min(M, N);
// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
@@ -26,63 +30,91 @@ void lu_factor_impl(
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
a,
lu,
a.shape(),
a.strides(),
strides,
0,
0,
CopyType::GeneralGeneral,
stream);
auto a_ptr = lu.data<float>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
MLX_LAPACK_FUNC(sgetrf)
(/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(lu);
encoder.set_output_array(pivots);
encoder.set_output_array(row_indices);
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
encoder.dispatch(
[a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {
int info;
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
getrf<T>(
/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
// Subtract 1 to get 0-based index
for (int j = 0; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
// Subtract 1 to get 0-based index
int j = 0;
for (; j < K; ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < M; ++j) {
row_indices_ptr[j] = j;
}
for (int j = K - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += K;
row_indices_ptr += M;
}
});
}
void LUF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
switch (inputs[0].dtype()) {
case float32:
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
case float64:
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
break;
default:
throw std::runtime_error(
"[LUF::eval_cpu] only supports float32 or float64.");
}
}
} // namespace mlx::core
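The loop after getrf converts LAPACK's 1-based, sequential-swap ipiv encoding into the 0-based row permutation stored in row_indices: make the pivots 0-based, start from an identity index vector, then replay the recorded swaps from the last to the first, exactly as the dispatch body above does. A small sketch (pivots_to_permutation is a hypothetical helper name):

#include <cstdio>
#include <utility>
#include <vector>

std::vector<int> pivots_to_permutation(std::vector<int> ipiv, int m) {
  int k = static_cast<int>(ipiv.size());
  for (int j = 0; j < k; ++j) {
    ipiv[j]--; // 1-based -> 0-based
  }
  std::vector<int> perm(m);
  for (int j = 0; j < m; ++j) {
    perm[j] = j; // identity
  }
  for (int j = k - 1; j >= 0; --j) {
    std::swap(perm[j], perm[ipiv[j]]); // replay swaps in reverse
  }
  return perm;
}

int main() {
  // Swaps recorded by a 3x3 factorization: row 0 <-> row 2, row 1 <-> row 2.
  auto perm = pivots_to_permutation({3, 3, 3}, 3);
  std::printf("%d %d %d\n", perm[0], perm[1], perm[2]); // 1 2 0
  return 0;
}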

View File

@@ -18,13 +18,16 @@ if [ "$CLANG" = "TRUE" ]; then
#include <complex>
#include <cstdint>
#include <vector>
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
#include <arm_fp16.h>
#endif
EOM
CC_FLAGS="-arch ${ARCH}"
CC_FLAGS="-arch ${ARCH} -nobuiltininc -nostdinc"
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -5,6 +5,7 @@
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
@@ -58,42 +59,42 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
[s = stream()](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(false, stx, arr_copy, true);
}
return std::make_tuple(false, stx, arr);
return std::make_tuple(false, stx, arr, false);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
copy(arr, arr_copy, CopyType::Vector, s);
return std::make_tuple(true, sty, arr_copy, true);
}
return std::make_tuple(true, sty, arr);
return std::make_tuple(true, sty, arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
copy(arr, arr_copy, CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, arr_copy, true);
}
};
bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
auto [a_transposed, lda, a, a_copied] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
@@ -104,31 +105,39 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}
auto mask_array = [](const array& mask,
auto mask_array = [](const void* mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
size_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
auto ndim = mask_shape.size();
auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
mask_shape[ndim - 1] * mask_shape[ndim - 2] * batch_idx,
mask_shape,
mask_strides);
auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
auto X_mask_str = mask_strides[ndim - 2];
auto Y_mask_str = mask_strides[ndim - 1];
if (mask.dtype() == bool_) {
if (is_bool) {
return mask_matrix(
data,
mask.data<bool>(),
static_cast<const bool*>(mask),
block_size,
X,
Y,
@@ -140,7 +149,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
return mask_matrix(
data,
mask.data<float>(),
static_cast<const float*>(mask),
block_size,
X,
Y,
@@ -152,61 +161,155 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
}
};
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
block_size_,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);
// Zero out blocks in out
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
}
}
encoder.set_input_array(a);
encoder.set_input_array(b);
const void* a_mask_ptr;
const void* b_mask_ptr;
const void* out_mask_ptr;
Shape a_mask_shape;
Shape b_mask_shape;
Shape out_mask_shape;
Strides a_mask_strides;
Strides b_mask_strides;
Strides out_mask_strides;
bool a_mask_bool;
bool b_mask_bool;
bool out_mask_bool;
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1];
a_mask_ptr = a_mask.data<void>();
b_mask_ptr = b_mask.data<void>();
a_mask_shape = a_mask.shape();
b_mask_shape = b_mask.shape();
a_mask_strides = a_mask.strides();
b_mask_strides = b_mask.strides();
a_mask_bool = (a_mask.dtype() == bool_);
b_mask_bool = (b_mask.dtype() == bool_);
encoder.set_input_array(a_mask);
encoder.set_input_array(b_mask);
}
if (has_out_mask) {
auto& out_mask = inputs[2];
out_mask_ptr = out_mask.data<void>();
out_mask_bool = (out_mask.dtype() == bool_);
encoder.set_input_array(out_mask);
out_mask_shape = out_mask.shape();
out_mask_strides = out_mask.strides();
}
encoder.set_output_array(out);
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
size_t num_matrices = out.size() / (M * size_t(N));
auto ldc = out.shape(-1);
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_mask_ptr,
b_mask_ptr,
out_mask_ptr,
has_op_mask,
has_out_mask,
block_size = block_size_,
num_matrices,
M,
N,
K,
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb,
ldc,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides(),
a_mask_shape = std::move(a_mask_shape),
b_mask_shape = std::move(b_mask_shape),
out_mask_shape = std::move(out_mask_shape),
a_mask_strides = std::move(a_mask_strides),
b_mask_strides = std::move(b_mask_strides),
out_mask_strides = std::move(out_mask_strides),
mask_array,
a_mask_bool,
b_mask_bool,
out_mask_bool]() {
for (int i = 0; i < num_matrices; ++i) {
// Adjust pointer
float* ai = a_ptr + elem_to_loc(M * K * i, a_shape, a_strides);
float* bi = b_ptr + elem_to_loc(K * N * i, b_shape, b_strides);
float* ci = out_ptr + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
mask_array(
a_mask_ptr,
ai,
block_size,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1,
a_mask_shape,
a_mask_strides,
a_mask_bool);
mask_array(
b_mask_ptr,
bi,
block_size,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1,
b_mask_shape,
b_mask_strides,
b_mask_bool);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
ldc);
// Zero out blocks in out
if (has_out_mask) {
mask_array(
out_mask_ptr,
ci,
block_size,
i,
M,
N,
N,
1,
out_mask_shape,
out_mask_strides,
out_mask_bool);
}
}
});
if (a_copied) {
encoder.add_temporary(a);
}
if (b_copied) {
encoder.add_temporary(b);
}
}
@@ -215,12 +318,13 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [s = stream(), &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -228,10 +332,10 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};
@@ -246,8 +350,12 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<float>(), size = out.size()]() {
std::fill_n(out_ptr, size, 0);
});
return;
}
@@ -272,29 +380,61 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto ldc = out.shape(-1);
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
encoder.dispatch([a_ptr = a.data<float>(),
b_ptr = b.data<float>(),
out_ptr = out.data<float>(),
M,
N,
K,
lda = lda,
ldb = ldb,
a_transposed = a_transposed,
b_transposed = b_transposed,
ldc,
lhs_indices_ptr,
rhs_indices_ptr,
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
batch_size_out,
matrix_stride_out,
batch_shape_A = std::move(batch_shape_A),
batch_shape_B = std::move(batch_shape_B),
batch_strides_A = std::move(batch_strides_A),
batch_strides_B = std::move(batch_strides_B)]() {
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a_ptr + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b_ptr + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out_ptr + matrix_stride_out * i,
ldc);
}
});
encoder.add_temporaries(std::move(temps));
}
} // namespace mlx::core
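Both kernels above capture raw pointers plus shapes and strides, and recover per-matrix offsets with elem_to_loc. A minimal re-implementation of what that helper presumably does (map a flat row-major element index to a storage offset under arbitrary strides; elem_to_loc_sketch is an illustrative name):

#include <cstdint>
#include <cstdio>
#include <vector>

int64_t elem_to_loc_sketch(
    int64_t elem,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  // Peel off row-major coordinates from the innermost axis outward and
  // weight each by its stride.
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 row-major view over column-major storage: strides {1, 2}.
  std::vector<int> shape{2, 3};
  std::vector<int64_t> strides{1, 2};
  // Element (0, 1) is flat index 1; it lives at storage offset 2.
  std::printf(
      "%lld\n", static_cast<long long>(elem_to_loc_sketch(1, shape, strides)));
  return 0;
}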

View File

@@ -3,18 +3,76 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
void matmul_dispatch(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta,
Stream stream) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
T* out_ptr = out.data<T>();
size_t ldc = out.shape(-1);
size_t batch_size = a.size() / (a.shape(-2) * a.shape(-1));
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape = a.shape(),
a_strides = a.strides(),
b_shape = b.shape(),
b_strides = b.strides()]() {
matmul<T>(
a_ptr,
b_ptr,
out_ptr,
a_transposed,
b_transposed,
lda,
ldb,
ldc,
alpha,
beta,
batch_size,
a_shape,
a_strides,
b_shape,
b_strides);
});
}
void matmul_general(
const array& a_pre,
const array& b_pre,
array& out,
Stream stream,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
std::vector<array> temps;
auto check_transpose = [stream, &temps](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
@@ -22,10 +80,10 @@ void matmul_general(
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, stream);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return std::make_tuple(false, stx, temps.back());
}
};
@@ -39,28 +97,34 @@ void matmul_general(
}
if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float16) {
matmul<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
cpu::get_command_encoder(stream).add_temporaries(std::move(temps));
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(), nbytes = out.nbytes()]() {
std::memset(out_ptr, 0, nbytes);
});
return;
}
return matmul_general(inputs[0], inputs[1], out);
matmul_general(inputs[0], inputs[1], out, stream());
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -74,9 +138,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype);
copy(c, out, ctype, stream());
return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}
} // namespace mlx::core
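check_transpose in matmul_general decides, from the last two strides, whether an operand can be handed to BLAS as-is (row-contiguous), as a transpose (column-contiguous), or must be copied first. The stride test in isolation (names here are illustrative):

#include <cstdint>
#include <cstdio>

struct LayoutDecision {
  bool transposed; // pass CblasTrans for this operand
  int64_t ld;      // leading dimension to hand to BLAS
  bool needs_copy; // strides match neither supported layout
};

LayoutDecision check_transpose_sketch(
    int rows, int cols, int64_t stx, int64_t sty) {
  if (stx == cols && sty == 1) {
    return {false, stx, false}; // row-contiguous
  } else if (stx == 1 && sty == rows) {
    return {true, sty, false}; // column-contiguous, i.e. transposed
  }
  return {false, cols, true}; // general strides: copy to row-contiguous
}

int main() {
  auto d = check_transpose_sketch(4, 8, 1, 4); // column-major 4x8
  std::printf("%d %lld %d\n", d.transposed, (long long)d.ld, d.needs_copy);
  return 0;
}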

View File

@@ -7,11 +7,11 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -21,40 +21,59 @@ namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
int64_t compute_dynamic_offset(
static std::pair<array, bool> compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
const std::vector<int>& axes,
Stream stream) {
array offset({1}, int64, nullptr, {});
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(indices);
encoder.set_output_array(offset);
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < axes.size(); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
encoder.dispatch(compute_offset, indices.data<uint8_t>());
break;
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
encoder.dispatch(compute_offset, indices.data<uint16_t>());
break;
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
encoder.dispatch(compute_offset, indices.data<uint32_t>());
break;
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
encoder.dispatch(compute_offset, indices.data<uint64_t>());
break;
default:
throw std::runtime_error("Invalid indices type.");
}
return {offset, donate};
}
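Stripped of the encoder and donation machinery, the offset computed above is just the dot product of the runtime index values with the strides of the sliced axes. A sketch (dynamic_offset is a hypothetical name):

#include <cstdint>
#include <cstdio>
#include <vector>

int64_t dynamic_offset(
    const std::vector<int64_t>& indices, // one index per sliced axis
    const std::vector<int64_t>& strides, // strides of the full array
    const std::vector<int>& axes) {      // which axis each index addresses
  int64_t offset = 0;
  for (size_t i = 0; i < axes.size(); ++i) {
    offset += indices[i] * strides[axes[i]];
  }
  return offset;
}

int main() {
  // 4x5 row-major array (strides {5, 1}); slice starting at (2, 3).
  std::printf(
      "%lld\n", (long long)dynamic_offset({2, 3}, {5, 1}, {0, 1})); // 13
  return 0;
}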
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -104,14 +123,59 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint16:
arange<uint16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint32:
arange<uint32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case uint64:
arange<uint64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int8:
arange<int8_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int16:
arange<int16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int32:
arange<int32_t>(start_, start_ + step_, out, out.size(), stream());
break;
case int64:
arange<int64_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float16:
arange<float16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case float32:
arange<float>(start_, start_ + step_, out, out.size(), stream());
break;
case float64:
arange<double>(start_, start_ + step_, out, out.size(), stream());
break;
case bfloat16:
arange<bfloat16_t>(start_, start_ + step_, out, out.size(), stream());
break;
case complex64:
arange<complex64_t>(start_, start_ + step_, out, out.size(), stream());
break;
}
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
copy(in, out, ctype, stream());
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -122,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
@@ -134,18 +198,20 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
copy(in, out, CopyType::General, stream());
}
}
@@ -169,14 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
copy(in, out, ctype, stream());
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -192,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
copy(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
@@ -207,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -219,43 +278,53 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(inputs[0]);
encoder.set_output_array(out);
encoder.dispatch([kptr,
cptr,
bytes_per_key,
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, kshape, kstrides);
auto k2_elem = elem_to_loc(kidx + 1, kshape, kstrides);
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
});
}
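The RandomBits loop fills bytes_per_key bytes per key from a counter-based generator that yields two 32-bit words per call: whole words are stored directly, and a trailing partial word is byte-copied, which is what the bytes_per_key % 4 branch above handles. A sketch of the same fill pattern, with a toy mixer standing in for random::threefry2x32_hash (toy_hash is explicitly not threefry):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <utility>

std::pair<uint32_t, uint32_t> toy_hash(
    std::pair<uint32_t, uint32_t> key,
    std::pair<uint64_t, uint64_t> count) {
  // A stand-in mixer for illustration only; the real kernel uses threefry.
  uint32_t a = key.first ^ static_cast<uint32_t>(count.first * 0x9E3779B9u);
  uint32_t b = key.second ^ static_cast<uint32_t>(count.second * 0x85EBCA6Bu);
  return {a * 0x27D4EB2Fu, b * 0x165667B1u};
}

void fill_random_bytes(
    char* out, size_t nbytes, std::pair<uint32_t, uint32_t> key) {
  size_t nwords = nbytes / 4;
  uint64_t counter = 0;
  size_t written = 0;
  // Full pairs of 32-bit words go straight into the output buffer.
  while (written + 8 <= nwords * 4) {
    auto rb = toy_hash(key, {counter, counter + 1});
    std::memcpy(out + written, &rb.first, 4);
    std::memcpy(out + written + 4, &rb.second, 4);
    written += 8;
    counter += 2;
  }
  // Trailing word(s) and any partial word are byte-copied from one last
  // generator call, mirroring the bytes_per_key % 4 handling above.
  if (written < nbytes) {
    auto rb = toy_hash(key, {counter, counter + 1});
    uint32_t words[2] = {rb.first, rb.second};
    std::memcpy(out + written, words, nbytes - written);
  }
}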
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -268,17 +337,24 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in_offset));
}
}
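`compute_dynamic_offset` now returns the offset as an array plus a flag saying whether it could donate an input buffer. Because the copy runs asynchronously, a freshly allocated offset array must stay alive until the task executes, which is what `add_temporary` guarantees. A toy model of why that bookkeeping exists (assumption: names and types here are illustrative, not the real encoder):

```cpp
#include <functional>
#include <memory>
#include <queue>
#include <vector>

// Toy model of `add_temporary`: tasks run later, so any buffer a task reads
// must be owned by something that outlives the enqueue call.
struct ToyEncoder {
  std::queue<std::function<void()>> tasks;
  std::vector<std::shared_ptr<std::vector<int>>> temporaries;

  void dispatch(std::function<void()> f) { tasks.push(std::move(f)); }
  void add_temporary(std::shared_ptr<std::vector<int>> t) {
    temporaries.push_back(std::move(t)); // keep alive until tasks drain
  }
  void run_all() {
    while (!tasks.empty()) { tasks.front()(); tasks.pop(); }
    temporaries.clear(); // safe to release only after the work ran
  }
};

int main() {
  ToyEncoder enc;
  auto offset = std::make_shared<std::vector<int>>(1, 42);
  enc.dispatch([p = offset.get()] { (void)p->at(0); }); // raw use, no ownership
  enc.add_temporary(offset); // without this, `offset` could die too early
  offset.reset();
  enc.run_all();
}
```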
void DynamicSliceUpdate::eval_cpu(
@@ -296,9 +372,10 @@ void DynamicSliceUpdate::eval_cpu(
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
auto [out_offset, donated] =
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
@@ -306,8 +383,14 @@ void DynamicSliceUpdate::eval_cpu(
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream(),
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
if (!donated) {
cpu::get_command_encoder(stream()).add_temporary(std::move(out_offset));
}
}
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -329,7 +412,7 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
@@ -344,7 +427,8 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
/* CopyType ctype = */ CopyType::GeneralGeneral,
stream());
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -368,13 +452,13 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
tmp.set_data(allocator::malloc(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
copy_inplace(in_tmp, tmp, CopyType::General, stream());
} else {
copy_inplace(in, tmp, CopyType::General);
copy_inplace(in, tmp, CopyType::General, stream());
}
auto flags = out.flags();
@@ -382,7 +466,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}

View File

@@ -2,53 +2,21 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A into a working buffer (LAPACK factors in place) and make it col-contiguous
array in(a.shape(), float32, nullptr, {});
array in(a.shape(), a.dtype(), nullptr, {});
auto flags = in.flags();
// Copy the input to be column contiguous
@@ -57,105 +25,123 @@ void qrf_impl(const array& a, array& q, array& r) {
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));
T optimal_work;
int lwork = -1;
int info;
auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>();
auto q_ptr = q.data<T>();
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
encoder.set_input_array(in);
encoder.set_output_array(q);
encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N);
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
T optimal_work;
int lwork = -1;
int info;
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
// Compute workspace size
geqrf<T>(&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc(sizeof(T) * lwork);
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < r.shape(-1); ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] =
in.data<T>()[i * N * M + j + k * M];
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
&N,
in_ptr + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < num_reflectors; ++j) {
for (int k = 0; k < j; ++k) {
r_ptr[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < N; ++k) {
r_ptr[i * N * num_reflectors + j * N + k] =
in_ptr[i * N * M + j + k * M];
}
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&num_reflectors,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
// Get work size
lwork = -1;
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
in.data<float>() + M * N * i,
nullptr,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
nullptr,
&optimal_work,
&lwork,
&info);
}
lwork = optimal_work;
work = allocator::malloc(sizeof(T) * lwork);
q.set_data(allocator::malloc_or_wait(q.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < q.shape(-2); ++j) {
for (int k = 0; k < q.shape(-1); ++k) {
q.data<T>()[i * M * num_reflectors + j * num_reflectors + k] =
in.data<T>()[i * N * M + j + k * M];
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
orgqr<T>(
&M,
&num_reflectors,
&num_reflectors,
in_ptr + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < M; ++j) {
for (int k = 0; k < num_reflectors; ++k) {
q_ptr[i * M * num_reflectors + j * num_reflectors + k] =
in_ptr[i * N * M + j + k * M];
}
}
}
}
// Cleanup
allocator::free(work);
allocator::free(tau);
// Cleanup
allocator::free(work);
allocator::free(tau);
});
encoder.add_temporary(in);
}
void QRF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
switch (inputs[0].dtype()) {
case float32:
qrf_impl<float>(inputs[0], outputs[0], outputs[1], stream());
break;
case float64:
qrf_impl<double>(inputs[0], outputs[0], outputs[1], stream());
break;
default:
throw std::runtime_error(
"[QRF::eval_cpu] only supports float32 or float64.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core
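Both `geqrf` and `orgqr` above follow the standard LAPACK two-call idiom: a first call with `lwork = -1` only writes the optimal workspace size into the `work` argument, and the actual computation happens in a second call with that workspace allocated. A minimal standalone example for a single float32 matrix, using the `sgeqrf_`/`sorgqr_` symbols whose signatures appear in the removed `lpack` wrapper:

```cpp
#include <algorithm>
#include <vector>

extern "C" {
void sgeqrf_(const int* m, const int* n, float* a, const int* lda, float* tau,
             float* work, const int* lwork, int* info);
void sorgqr_(const int* m, const int* n, const int* k, float* a,
             const int* lda, const float* tau, float* work, const int* lwork,
             int* info);
}

// QR of a single column-major M x N matrix via the two-call LAPACK idiom.
void qr_single(std::vector<float>& a, int M, int N) {
  int k = std::min(M, N);
  std::vector<float> tau(k);
  int info = 0;

  // Workspace query: lwork = -1 asks LAPACK for the optimal size.
  float opt = 0;
  int lwork = -1;
  sgeqrf_(&M, &N, a.data(), &M, tau.data(), &opt, &lwork, &info);
  lwork = static_cast<int>(opt);
  std::vector<float> work(lwork);

  // Real factorization: A is overwritten with R (upper) and the reflectors.
  sgeqrf_(&M, &N, a.data(), &M, tau.data(), work.data(), &lwork, &info);

  // Same dance again to expand the reflectors into an explicit Q.
  lwork = -1;
  sorgqr_(&M, &k, &k, a.data(), &M, tau.data(), &opt, &lwork, &info);
  lwork = static_cast<int>(opt);
  work.resize(lwork);
  sorgqr_(&M, &k, &k, a.data(), &M, tau.data(), work.data(), &lwork, &info);
}
```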

View File

@@ -3,6 +3,7 @@
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
@@ -164,8 +165,8 @@ simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
} else if constexpr (bits == 8 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto l = simd::Simd<uint32_t, 4>(*w++);
auto r = simd::Simd<uint32_t, 4>(*w);
auto l = simd::Simd<uint32_t, S / 2>(*w++);
auto r = simd::Simd<uint32_t, S / 2>(*w);
wi = simd::Simd<uint32_t, S>(l, r);
wi = wi >> shifts;
wi = wi & bitmask;
@@ -316,7 +317,8 @@ void _qmm_dispatch_typed(
}
}
void _qmm_dispatch(
template <typename T>
void _qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
@@ -328,63 +330,61 @@ void _qmm_dispatch(
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
}
void _bs_qmm_dispatch(
void _qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int bits,
int group_size,
bool transposed_w) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
const array& x,
const array& w,
@@ -402,60 +402,90 @@ void _bs_qmm_dispatch(
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
biases_ptr +
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
}
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
switch (x.dtype()) {
case float32:
_bs_qmm_dispatch_typed<float>(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
bits,
group_size,
transposed_w);
break;
case float16:
_bs_qmm_dispatch_typed<float16_t>(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_bs_qmm_dispatch_typed<bfloat16_t>(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
@@ -469,13 +499,14 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto ensure_row_contiguous = [](const array& arr) {
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -484,8 +515,25 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
}
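The capture list above is the other half of the encoder pattern: `array::unsafe_weak_copy`, as used throughout this diff, appears to hand the closure a non-owning view of the array so that capturing it does not extend the lifetime of the computation graph; the buffer itself is kept valid by `set_input_array`/`set_output_array`/`add_temporaries`. A hedged sketch of the idea (assumption, not the real `mlx::core::array` API):

```cpp
#include <memory>

// Toy model: a weak copy aliases the buffer without owning it, so a queued
// closure can read the data without keeping the producing graph node alive.
struct ToyArray {
  std::shared_ptr<float[]> owned; // strong copies share ownership
  float* alias = nullptr;         // weak copies only carry this

  float* data() const { return owned ? owned.get() : alias; }

  static ToyArray unsafe_weak_copy(const ToyArray& a) {
    ToyArray w;
    w.alias = a.data();
    return w;
  }
};
```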
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -498,15 +546,17 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
std::vector<array> temps;
auto ensure_row_contiguous_last_dims = [s = stream(),
&temps](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
copy(arr, temps.back(), CopyType::General, s);
return temps.back();
}
};
@@ -515,71 +565,86 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
}
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
const T* w,
U* out,
T* scales,
T* biases,
int bits,
int group_size) {
const T* w = w_.data<T>();
int group_size,
size_t w_size) {
float n_bins = (1 << bits) - 1;
float eps = 1e-7;
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
size_t n_groups = w_size / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
float w_min = std::numeric_limits<float>::infinity();
float w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
w_max = std::max(w_max, (float)w[w_idx + j]);
w_min = std::min(w_min, (float)w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
float scale = std::max((w_max - w_min) / n_bins, eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
float edge = mask ? w_min : w_max;
float q0 = std::rint(edge / scale);
float bias = 0;
if (q0 != 0) {
scale = edge / q0;
bias = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
float w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - bias) / scale);
w_el = std::min(std::max(w_el, 0.0f), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
@@ -590,53 +655,91 @@ void quantize(
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
scales[i] = static_cast<T>(scale);
biases[i] = static_cast<T>(bias);
}
}
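In the rewritten `quantize`, all intermediate math is done in `float` regardless of `T`, and the scale and bias are cast back to `T` only when stored. Per group of `group_size` weights, the code computes an affine code so that each weight is reconstructed as $\hat w \approx s\,q + \beta$:

```latex
n_{\text{bins}} = 2^{b} - 1,\qquad
s = \max\!\Big(\frac{w_{\max} - w_{\min}}{n_{\text{bins}}},\ \varepsilon\Big),\qquad
q = \operatorname{clamp}\!\Big(\operatorname{rint}\frac{w - \beta}{s},\ 0,\ n_{\text{bins}}\Big)
```

The `q0` adjustment picks the sign of $s$ and the bias $\beta$ so that the group's largest-magnitude endpoint (the "edge") lands exactly on an integer code; when the edge rounds to code zero, the bias stays zero.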
template <typename T, typename U>
void dispatch_quantize(
const array& w,
array& out,
array& scales,
array& biases,
int bits,
int group_size) {
auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
return std::make_pair(arr, false);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
copy(arr, arr_copy, CopyType::General, s);
return std::make_pair(arr_copy, true);
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
encoder.add_temporary(w);
}
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w = array::unsafe_weak_copy(w),
out = array::unsafe_weak_copy(out),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_]() mutable {
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
});
}
} // namespace mlx::core

View File

@@ -5,6 +5,7 @@
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
@@ -139,25 +140,22 @@ void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U init) {
ReductionPlan plan = get_reduction_plan(x, axes);
auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
return;
}
@@ -166,8 +164,6 @@ void reduction_op(
int reduction_size = plan.shape.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
@@ -175,7 +171,7 @@ void reduction_op(
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
@@ -184,10 +180,10 @@ void reduction_op(
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
x_ptr + offset + extra_offset,
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
op,
Op{},
init);
},
plan.shape,
@@ -202,12 +198,10 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
x_ptr += reduction_stride * reduction_size;
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
return;
@@ -219,15 +213,14 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride;
}
} else {
@@ -237,11 +230,11 @@ void reduction_op(
nd_loop(
[&](int extra_offset) {
strided_reduce(
x_ptr + offset + extra_offset,
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
op);
Op{});
},
plan.shape,
plan.strides);
@@ -252,15 +245,14 @@ void reduction_op(
}
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = op(val, *(x_ptr + offset + extra_offset));
val = Op{}(val, *(in_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
@@ -396,9 +388,9 @@ void reduce_dispatch_and_or(
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
}
}
@@ -410,15 +402,15 @@ void reduce_dispatch_sum_prod(
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
} else {
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
} else {
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
}
}
}
@@ -431,132 +423,141 @@ void reduce_dispatch_min_max(
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
}
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axes_ = axes_]() mutable {
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
break;
}
}
});
}
} // namespace mlx::core
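The reduction refactor also changes `reduction_op` to take the operator as a template parameter and default-construct it (`Op{}`) at each call site instead of threading a runtime `op` argument through. For a stateless functor the two are equivalent, but the template form fixes the exact functor type at every call, which makes inlining trivial. A minimal sketch:

```cpp
#include <cstddef>

struct SumReduce {
  template <typename U, typename T>
  U operator()(U acc, T x) const { return acc + static_cast<U>(x); }
};

// The operator is a type, not a value: Op{} costs nothing for an empty
// functor and the call below compiles down to a plain add.
template <typename T, typename U, typename Op>
U reduce(const T* x, size_t n, U init) {
  U acc = init;
  for (size_t i = 0; i < n; ++i) {
    acc = Op{}(acc, x[i]);
  }
  return acc;
}

// Usage: float s = reduce<float, float, SumReduce>(data, n, 0.0f);
```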

View File

@@ -4,6 +4,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
@@ -153,33 +154,31 @@ void strided_scan(
template <typename T, typename U, typename Op>
void scan_op(
const array& input,
array& output,
const array& in,
array& out,
int axis,
bool reverse,
bool inclusive,
const Op& op,
U init) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
if (in.flags().row_contiguous) {
if (in.strides()[axis] == 1) {
contiguous_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
in.data<T>(),
out.data<U>(),
in.size() / in.shape(axis),
in.shape(axis),
reverse,
inclusive,
op,
init);
} else {
strided_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
in.data<T>(),
out.data<U>(),
in.size() / in.shape(axis) / in.strides()[axis],
in.shape(axis),
in.strides()[axis],
reverse,
inclusive,
op,
@@ -193,8 +192,8 @@ void scan_op(
template <typename T, typename U>
void scan_dispatch(
Scan::ReduceType rtype,
const array& input,
array& output,
const array& in,
array& out,
int axis,
bool reverse,
bool inclusive) {
@@ -202,29 +201,29 @@ void scan_dispatch(
case Scan::Sum: {
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::Prod: {
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::Min: {
auto op = [](U y, T x) { return x < y ? x : y; };
auto init = (issubdtype(input.dtype(), floating))
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::Max: {
auto op = [](U y, T x) { return x < y ? y : x; };
auto init = (issubdtype(input.dtype(), floating))
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
}
@@ -235,82 +234,95 @@ void scan_dispatch(
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& encoder = cpu::get_command_encoder(stream());
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General);
copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
}
out.set_data(allocator::malloc(out.nbytes()));
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
reduce_type_ = reduce_type_,
reverse_ = reverse_,
inclusive_ = inclusive_]() mutable {
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
}
break;
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
});
}
} // namespace mlx::core
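For reference, the semantics the dispatched scans implement: an inclusive scan includes the current element in its own prefix, an exclusive scan shifts the results by one and seeds with `init`, and `reverse` runs the same recurrence from the back. A scalar sketch (the `plus` lambda in the usage comment is illustrative):

```cpp
#include <cstddef>
#include <vector>

template <typename T, typename U, typename Op>
std::vector<U> scan(const std::vector<T>& x, U init, Op op,
                    bool reverse, bool inclusive) {
  std::vector<U> y(x.size());
  U acc = init;
  for (size_t k = 0; k < x.size(); ++k) {
    size_t i = reverse ? x.size() - 1 - k : k;
    if (inclusive) {
      acc = op(acc, x[i]);  // element contributes to its own output
      y[i] = acc;
    } else {
      y[i] = acc;           // output holds the prefix before this element
      acc = op(acc, x[i]);
    }
  }
  return y;
}

// auto plus = [](int a, int b) { return a + b; };
// scan<int, int>({1, 2, 3}, 0, plus, false, true)  -> {1, 3, 6}
// scan<int, int>({1, 2, 3}, 0, plus, false, false) -> {0, 1, 3}
```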

View File

@@ -16,51 +16,70 @@ void select_op(
const array& b,
const array& c,
array& out,
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
Op op,
Stream stream) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
c = array::unsafe_weak_copy(c),
out = array::unsafe_weak_copy(out),
op,
topt]() mutable {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(
a, b, c, out, op, topt);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
a, b, c, out, op, topt);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
a, b, c, out, op, topt);
break;
}
});
}
} // namespace
@@ -70,7 +89,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
select_op(condition, a, b, out, detail::Select(), stream());
}
} // namespace mlx::core
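`select_op` now computes `topt = get_ternary_op_type(a, b, c)` once on the calling thread, sizes the output with `set_ternary_op_output_data`, and passes the result into the dispatched closure so the layout analysis is not repeated per dtype case. A hedged sketch of what such a classification typically looks like (assumption: the real `TernaryOpType` lives elsewhere in the backend and may differ):

```cpp
// Hedged sketch: classify how three inputs can be iterated together.
enum class TernaryOpType {
  ScalarScalarScalar, // all inputs are broadcast scalars
  VectorVectorVector, // all contiguous with matching layout: one flat loop
  General,            // anything else: strided element-by-element access
};

template <typename A>
TernaryOpType classify(const A& a, const A& b, const A& c) {
  if (a.size() == 1 && b.size() == 1 && c.size() == 1) {
    return TernaryOpType::ScalarScalarScalar;
  }
  if (a.flags().contiguous && b.flags().contiguous && c.flags().contiguous &&
      a.strides() == b.strides() && b.strides() == c.strides()) {
    return TernaryOpType::VectorVectorVector;
  }
  return TernaryOpType::General;
}
```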

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif
template <>
static constexpr int max_size<float16_t> = N;
inline constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
static constexpr int max_size<int8_t> = 16;
inline constexpr int max_size<int8_t> = 16;
template <>
static constexpr int max_size<int16_t> = 16;
inline constexpr int max_size<int16_t> = 16;
template <>
static constexpr int max_size<int> = 8;
inline constexpr int max_size<int> = 8;
template <>
static constexpr int max_size<int64_t> = 4;
inline constexpr int max_size<int64_t> = 4;
template <>
static constexpr int max_size<uint8_t> = 16;
inline constexpr int max_size<uint8_t> = 16;
template <>
static constexpr int max_size<uint16_t> = 16;
inline constexpr int max_size<uint16_t> = 16;
template <>
static constexpr int max_size<uint32_t> = 8;
inline constexpr int max_size<uint32_t> = 8;
template <>
static constexpr int max_size<uint64_t> = 4;
inline constexpr int max_size<uint64_t> = 4;
template <>
static constexpr int max_size<float> = 8;
inline constexpr int max_size<float> = 8;
template <>
static constexpr int max_size<double> = 4;
inline constexpr int max_size<double> = 4;
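The `static constexpr` to `inline constexpr` change here is a language fix: an explicit specialization of a variable template may not carry a storage-class specifier, and `inline` additionally guarantees a single definition when the header is included from several translation units. Minimal before/after:

```cpp
// Before (ill-formed: storage class on an explicit specialization):
//   template <> static constexpr int max_size<float> = 8;

// After: valid, and `inline` merges definitions across translation units.
template <typename T>
inline constexpr int max_size = 1;

template <>
inline constexpr int max_size<float> = 8;

static_assert(max_size<float> == 8);
static_assert(max_size<double> == 1);
```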
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \

View File

@@ -87,7 +87,6 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
@@ -95,6 +94,17 @@ DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
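The complex branch above computes the natural log and rescales both components, using the identity below; for real types the overload still defers to `std::log2`:

```latex
\log_2 z \;=\; \frac{\ln z}{\ln 2}
\;=\; \frac{\operatorname{Re}(\ln z)}{\ln 2} \;+\; i\,\frac{\operatorname{Im}(\ln z)}{\ln 2}
```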
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;

View File

@@ -186,7 +186,7 @@ Simd<T, N> erfinv(Simd<T, N> a_) {
return a * rhs(t);
}
} else {
return a * select(t > thresh, lhs(t), rhs(t));
return a * select(abs(t) > thresh, lhs(t), rhs(t));
}
}

View File

@@ -4,6 +4,7 @@
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
@@ -15,92 +16,100 @@ namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
void softmax(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
const T* current_in_ptr;
T* current_out_ptr;
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
current_out_ptr++;
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
});
}
} // namespace
@@ -109,67 +118,52 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
auto set_output = [s = stream(), &out](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
copy(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
auto in = set_output(inputs[0]);
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out);
softmax<float, float>(in, out, stream());
break;
case float16:
if (precise_) {
softmax<float16_t, float>(in, out);
softmax<float16_t, float>(in, out, stream());
} else {
softmax<float16_t, float16_t>(in, out);
softmax<float16_t, float16_t>(in, out, stream());
}
break;
case bfloat16:
if (precise_) {
softmax<bfloat16_t, float>(in, out);
softmax<bfloat16_t, float>(in, out, stream());
} else {
softmax<bfloat16_t, bfloat16_t>(in, out);
softmax<bfloat16_t, bfloat16_t>(in, out, stream());
}
break;
case float64:
softmax<double, double>(in, out);
softmax<double, double>(in, out, stream());
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
default:
throw std::runtime_error(
"[softmax] Only defined for floating point types.");
break;
}
}
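
The rewritten eval_cpu above folds the contiguity check, buffer donation, and the fallback copy into a single set_output lambda. A condensed sketch of that pattern, using only the array/allocator/copy calls that appear in this diff (prepare_contiguous_input is a hypothetical name, not mlx API):

array prepare_contiguous_input(const array& x, array& out, Stream s) {
  if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
    if (x.is_donatable()) {
      out.copy_shared_buffer(x);  // reuse the input buffer; softmax runs in place
    } else {
      out.set_data(
          allocator::malloc(x.data_size() * x.itemsize()),
          x.data_size(),
          x.strides(),
          x.flags());
    }
    return x;
  }
  // Strided along the softmax axis: make a contiguous copy on stream s and
  // let the output alias it, so the kernel always sees dense rows.
  array x_copy(x.shape(), x.dtype(), nullptr, {});
  copy(x, x_copy, CopyType::General, s);
  out.copy_shared_buffer(x_copy);
  return x_copy;
}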


@@ -7,6 +7,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
@@ -103,16 +104,12 @@ struct StridedIterator {
T* ptr_;
};
template <typename T, typename IdxT = uint32_t>
void sort(const array& in, array& out, int axis) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
template <typename T>
void sort(array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -126,8 +123,9 @@ void sort(const array& in, array& out, int axis) {
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
@@ -139,9 +137,6 @@ void sort(const array& in, array& out, int axis) {
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@@ -167,9 +162,12 @@ void argsort(const array& in, array& out, int axis) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
@@ -191,33 +189,30 @@ void argsort(const array& in, array& out, int axis) {
}
}
template <typename T, typename IdxT = uint32_t>
void partition(const array& in, array& out, int axis, int kth) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
template <typename T>
void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = in.shape();
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
auto axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
@@ -230,9 +225,6 @@ void partition(const array& in, array& out, int axis, int kth) {
template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@@ -260,9 +252,13 @@ void argpartition(const array& in, array& out, int axis, int kth) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
@@ -291,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
case uint8:
return argsort<uint8_t>(in, out, axis_);
case uint16:
return argsort<uint16_t>(in, out, axis_);
case uint32:
return argsort<uint32_t>(in, out, axis_);
case uint64:
return argsort<uint64_t>(in, out, axis_);
case int8:
return argsort<int8_t>(in, out, axis_);
case int16:
return argsort<int16_t>(in, out, axis_);
case int32:
return argsort<int32_t>(in, out, axis_);
case int64:
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
// Allocate output
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
case uint8:
return argsort<uint8_t>(in, out, axis_);
case uint16:
return argsort<uint16_t>(in, out, axis_);
case uint32:
return argsort<uint32_t>(in, out, axis_);
case uint64:
return argsort<uint64_t>(in, out, axis_);
case int8:
return argsort<int8_t>(in, out, axis_);
case int16:
return argsort<int16_t>(in, out, axis_);
case int32:
return argsort<int32_t>(in, out, axis_);
case int64:
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
});
}
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return sort<bool>(in, out, axis_);
case uint8:
return sort<uint8_t>(in, out, axis_);
case uint16:
return sort<uint16_t>(in, out, axis_);
case uint32:
return sort<uint32_t>(in, out, axis_);
case uint64:
return sort<uint64_t>(in, out, axis_);
case int8:
return sort<int8_t>(in, out, axis_);
case int16:
return sort<int16_t>(in, out, axis_);
case int32:
return sort<int32_t>(in, out, axis_);
case int64:
return sort<int64_t>(in, out, axis_);
case float32:
return sort<float>(in, out, axis_);
case float64:
return sort<double>(in, out, axis_);
case float16:
return sort<float16_t>(in, out, axis_);
case bfloat16:
return sort<bfloat16_t>(in, out, axis_);
case complex64:
return sort<complex64_t>(in, out, axis_);
}
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
}
// Allocate output
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
}
});
}
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_);
case uint8:
return partition<uint8_t>(in, out, axis_, kth_);
case uint16:
return partition<uint16_t>(in, out, axis_, kth_);
case uint32:
return partition<uint32_t>(in, out, axis_, kth_);
case uint64:
return partition<uint64_t>(in, out, axis_, kth_);
case int8:
return partition<int8_t>(in, out, axis_, kth_);
case int16:
return partition<int16_t>(in, out, axis_, kth_);
case int32:
return partition<int32_t>(in, out, axis_, kth_);
case int64:
return partition<int64_t>(in, out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_);
case float64:
return partition<double>(in, out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return partition<complex64_t>(in, out, axis_, kth_);
}
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (out.dtype()) {
case bool_:
return partition<bool>(out, axis_, kth_);
case uint8:
return partition<uint8_t>(out, axis_, kth_);
case uint16:
return partition<uint16_t>(out, axis_, kth_);
case uint32:
return partition<uint32_t>(out, axis_, kth_);
case uint64:
return partition<uint64_t>(out, axis_, kth_);
case int8:
return partition<int8_t>(out, axis_, kth_);
case int16:
return partition<int16_t>(out, axis_, kth_);
case int32:
return partition<int32_t>(out, axis_, kth_);
case int64:
return partition<int64_t>(out, axis_, kth_);
case float32:
return partition<float>(out, axis_, kth_);
case float64:
return partition<double>(out, axis_, kth_);
case float16:
return partition<float16_t>(out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(out, axis_, kth_);
case complex64:
return partition<complex64_t>(out, axis_, kth_);
}
});
}
} // namespace mlx::core
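
All four sorting primitives now share the same skeleton: allocate or prepare the output on the calling thread, register the arrays with the CPU command encoder, then run the type-dispatched kernel inside a task. A minimal sketch of that skeleton, assuming the encoder API shown in this diff (eval_async is a hypothetical name):

void eval_async(const array& in, array& out, Stream s) {
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cpu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  // unsafe_weak_copy hands the task non-owning views; the encoder, not the
  // closure, is responsible for keeping the buffers alive until it runs.
  encoder.dispatch([in = array::unsafe_weak_copy(in),
                    out = array::unsafe_weak_copy(out)]() mutable {
    // switch (in.dtype()) { ... typed kernel body ... }
  });
}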


@@ -2,12 +2,18 @@
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
void svd_impl(const array& a, array& u, array& s, array& vt) {
template <typename T>
void svd_impl(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
Stream stream) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
@@ -21,129 +27,179 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
const int N = a.shape(-1);
const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), float32, nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
array in(a.shape(), a.dtype(), nullptr, {});
copy(
a,
in,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream);
// Allocate outputs.
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
T* vt_ptr;
static constexpr auto job_u = "V";
static constexpr auto job_vt = "V";
static constexpr auto range = "A";
if (compute_uv) {
array& u = outputs[0];
array& s = outputs[1];
array& vt = outputs[2];
// Will contain the number of singular values after the call has returned.
int ns = 0;
float workspace_dimension = 0;
u.set_data(allocator::malloc(u.nbytes()));
s.set_data(allocator::malloc(s.nbytes()));
vt.set_data(allocator::malloc(vt.nbytes()));
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
encoder.set_output_array(u);
encoder.set_output_array(s);
encoder.set_output_array(vt);
static const int lwork_query = -1;
s_ptr = s.data<T>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
array& s = outputs[0];
static const int ignored_int = 0;
static const float ignored_float = 0;
s.set_data(allocator::malloc(s.nbytes()));
int info;
encoder.set_output_array(s);
// Compute workspace size.
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
s_ptr = s.data<T>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
MLX_LAPACK_FUNC(sgesvdx)
(
auto job_u = (u_ptr) ? "V" : "N";
auto job_vt = (u_ptr) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const T ignored_float = 0;
int info;
// Compute workspace size.
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<float>() + M * N * i,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s.data<float>() + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt.data<float>() + N * N * i,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u.data<float>() + M * M * i,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesvdx<T>(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
}
});
encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
switch (inputs[0].dtype()) {
case float32:
svd_impl<float>(inputs[0], outputs, compute_uv_, stream());
break;
case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64.");
}
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
} // namespace mlx::core
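
The pair of gesvdx calls above follows the standard LAPACK two-phase workspace protocol: a first call with lwork = -1 only reports the optimal scratch size, and the actual factorization runs with a buffer of that size. A self-contained toy of the protocol (fake_driver is a stand-in, not a LAPACK or mlx symbol):

#include <cstdio>
#include <vector>

// Toy stand-in for a LAPACK driver such as gesvdx: when *lwork == -1 it only
// reports the optimal workspace size in work[0] and does no factorization.
void fake_driver(int n, float* work, const int* lwork, int* info) {
  if (*lwork == -1) {
    work[0] = static_cast<float>(4 * n);  // pretend the optimal size is 4 * n
    *info = 0;
    return;
  }
  // A real driver would use work[0 .. *lwork) as scratch space here.
  *info = 0;
}

int main() {
  int n = 8;
  int info = 0;
  const int lwork_query = -1;
  float workspace_dimension = 0;
  fake_driver(n, &workspace_dimension, &lwork_query, &info);  // phase 1: query
  const int lwork = static_cast<int>(workspace_dimension);
  std::vector<float> scratch(lwork);
  fake_driver(n, scratch.data(), &lwork, &info);              // phase 2: solve
  std::printf("queried lwork = %d, info = %d\n", lwork, info);
}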


@@ -1,10 +1,10 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
namespace mlx::core {
@@ -53,22 +53,18 @@ void ternary_op_dims(
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const T1* a_ptr,
const T2* b_ptr,
const T3* c_ptr,
U* out_ptr,
Op op,
size_t size,
Shape& shape,
std::vector<Strides>& strides) {
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<T3>();
int ndim = shape.size();
switch (ndim) {
case 1:
@@ -105,7 +101,7 @@ void ternary_op_dispatch_dims(
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
for (size_t elem = 0; elem < size; elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -130,18 +126,16 @@ void ternary_op(
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
Op op,
TernaryOpType topt) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
@@ -150,7 +144,10 @@ void ternary_op(
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
}
}
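
ternary_op_dispatch_dims now receives raw pointers plus a shape and strides that were already collapsed by collapse_contiguous_dims, which is what keeps the switch on ndim small. A standalone, single-array illustration of the merge rule behind that helper (the real one collapses several stride sets in lockstep):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape{2, 3, 4};
  std::vector<int> strides{12, 4, 1};  // row-major, fully contiguous
  std::vector<int> cshape{shape[0]};
  std::vector<int> cstrides{strides[0]};
  for (std::size_t i = 1; i < shape.size(); ++i) {
    // Merge axis i into its left neighbor when the two are "glued", i.e.
    // the neighbor's stride equals strides[i] * shape[i].
    if (cstrides.back() == strides[i] * shape[i]) {
      cshape.back() *= shape[i];
      cstrides.back() = strides[i];
    } else {
      cshape.push_back(shape[i]);
      cstrides.push_back(strides[i]);
    }
  }
  // Prints: collapsed to 1 axis(es), size 24, stride 1
  std::printf("collapsed to %zu axis(es), size %d, stride %d\n",
              cshape.size(), cshape[0], cstrides[0]);
}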


@@ -14,88 +14,57 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
auto op = detail::Abs{};
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
unary_signed(in, out, detail::Abs(), stream());
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos());
unary_fp(in, out, detail::ArcCos(), stream());
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh());
unary_fp(in, out, detail::ArcCosh(), stream());
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin());
unary_fp(in, out, detail::ArcSin(), stream());
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh());
unary_fp(in, out, detail::ArcSinh(), stream());
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan());
unary_fp(in, out, detail::ArcTan(), stream());
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh());
unary_fp(in, out, detail::ArcTanh(), stream());
}
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert());
unary_int(in, out, detail::BitwiseInvert(), stream());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
unary_fp(in, out, detail::Ceil(), stream());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -104,84 +73,50 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
unary_complex(inputs[0], out, detail::Conjugate(), stream());
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cos());
unary_fp(in, out, detail::Cos(), stream());
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh());
unary_fp(in, out, detail::Cosh(), stream());
}
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
unary_real_fp(in, out, detail::Erf(), stream());
}
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
unary_real_fp(in, out, detail::ErfInv(), stream());
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Exp());
unary_fp(in, out, detail::Exp(), stream());
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1());
unary_fp(in, out, detail::Expm1(), stream());
}
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
unary_fp(in, out, detail::Floor(), stream());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -189,7 +124,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -197,13 +132,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
unary_fp(in, out, detail::Log(), stream());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
unary_fp(in, out, detail::Log2(), stream());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
unary_fp(in, out, detail::Log10(), stream());
break;
}
}
@@ -211,30 +146,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p());
unary_fp(in, out, detail::Log1p(), stream());
}
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
unary(in, out, detail::LogicalNot(), stream());
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
unary(in, out, detail::Negative(), stream());
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
unary_complex_to_float(inputs[0], out, detail::Real(), stream());
}
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
unary_fp(in, out, detail::Round(), stream());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -244,7 +179,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid());
unary_fp(in, out, detail::Sigmoid(), stream());
}
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -253,48 +188,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
unary(in, out, detail::Sign(), stream());
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sin());
unary_fp(in, out, detail::Sin(), stream());
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh());
unary_fp(in, out, detail::Sinh(), stream());
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
unary(in, out, detail::Square(), stream());
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
unary_fp(in, out, detail::Rsqrt(), stream());
} else {
unary_fp(in, out, detail::Sqrt());
unary_fp(in, out, detail::Sqrt(), stream());
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tan());
unary_fp(in, out, detail::Tan(), stream());
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh());
unary_fp(in, out, detail::Tanh(), stream());
}
} // namespace mlx::core
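
For reference, the per-primitive switch tables deleted above collapse into a small family of dtype-category helpers (defined in the next file); this comment block only restates the mapping visible in this diff:

// unary(...)                  -> any dtype          (LogicalNot, Negative, Sign, Square)
// unary_fp(...)               -> floating + complex (ArcCos, Exp, Log, Sin, Tanh, ...)
// unary_real_fp(...)          -> real floating only (Erf, ErfInv)
// unary_int(...)              -> integer types      (BitwiseInvert)
// unary_signed(...)           -> signed types       (Abs)
// unary_complex(...)          -> complex64          (Conjugate)
// unary_complex_to_float(...) -> complex64 to float (Real, Imag)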


@@ -5,174 +5,296 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h"
namespace mlx::core {
void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
out.set_data(allocator::malloc(out.nbytes()));
}
}
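// Note on the change above: donation and layout-preserving allocation are now
// gated on in.flags().contiguous; a strided input instead gets a fresh dense
// buffer of out.nbytes(), which the strided loop in unary_op below fills.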
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
out[i] = Op{}(*a);
a += stride;
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
void unary_op(const array& a, array& out, Op) {
const T* src = a.data<T>();
U* dst = out.data<U>();
auto ndim = a.ndim();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
U* dst = out.data<U>();
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
size_t size = a.data_size();
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr)));
simd::store(dst, Op{}(simd::load<T, N>(src)));
size -= N;
a_ptr += N;
src += N;
dst += N;
}
while (size > 0) {
*dst = op(*a_ptr);
*dst = Op{}(*src);
size--;
dst++;
a_ptr++;
src++;
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {
unary_op(a_ptr, dst, op, shape, stride);
size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}
}
}
template <typename Op>
void unary(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
break;
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
void unary(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
break;
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
});
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_real] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_int(const array& a, array& out, Op op) {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
void unary_signed(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
});
}
template <typename Op>
void unary_complex(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
}
template <typename Op>
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch(
[a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
}
template <typename Op>
void unary_int(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
} // namespace mlx::core
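
The contiguous path in unary_op above is the classic width-N main loop followed by a scalar tail. A standalone toy with the same control flow, where plain loops stand in for simd::load/store (negate is just an example op):

#include <cstddef>
#include <cstdio>

template <int N>
void negate(const float* src, float* dst, std::size_t size) {
  while (size >= N) {  // vector body: N lanes per iteration
    for (int i = 0; i < N; ++i) {
      dst[i] = -src[i];  // stands in for simd::store(dst, Op{}(simd::load(...)))
    }
    src += N;
    dst += N;
    size -= N;
  }
  while (size > 0) {  // scalar tail: fewer than N elements remain
    *dst++ = -*src++;
    size--;
  }
}

int main() {
  float in[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  float out[10];
  negate<4>(in, out, 10);  // two vector iterations plus a 2-element tail
  std::printf("%g %g\n", out[0], out[9]);  // prints: -0 -9
}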


@@ -86,13 +86,14 @@ struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
return simd::select(x == z, z, o);
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
return simd::select(x < z, m, simd::select(x > z, o, z));
}
}
SINGLE()
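
The unsigned branch previously returned the comparison x != z directly; on SIMD lanes that comparison yields an all-ones mask rather than the numeric value 1, which is the bug the select form above fixes. A scalar reference for the intended semantics (standalone toy, not mlx code):

#include <cassert>

template <typename T>
T sign_unsigned(T x) {
  return x == T{0} ? T{0} : T{1};  // scalar analogue of select(x == z, z, o)
}

int main() {
  assert(sign_unsigned(0u) == 0u);
  assert(sign_unsigned(42u) == 1u);
}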


@@ -47,6 +47,7 @@ if(MLX_METAL_JIT)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(logsumexp)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
@@ -95,6 +96,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp


@@ -3,6 +3,7 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -10,6 +11,9 @@
namespace mlx::core {
constexpr size_t resource_options =
MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked;
namespace allocator {
Allocator& allocator() {
@@ -17,6 +21,9 @@ Allocator& allocator() {
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<MTL::Buffer*>(ptr_)->contents();
}
@@ -26,8 +33,11 @@ namespace metal {
namespace {
BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::BufferCache(ResidencySet& residency_set)
: head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
@@ -38,6 +48,9 @@ int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release();
n_release++;
}
@@ -95,6 +108,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
@@ -149,16 +165,35 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
buffer_cache_(residency_set_) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
std::get<size_t>(device_info().at("max_recommended_working_set_size"));
resource_limit_ = std::get<size_t>(device_info().at("resource_limit"));
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
bool is_vm = std::get<std::string>(device_info().at("device_name")) ==
"Apple Paravirtual device";
if (is_vm) {
return;
}
auto heap_desc = MTL::HeapDescriptor::alloc()->init();
heap_desc->setResourceOptions(resource_options);
heap_desc->setSize(heap_size_);
heap_ = device_->newHeap(heap_desc);
heap_desc->release();
residency_set_.insert(heap_);
}
MetalAllocator::~MetalAllocator() {
auto pool = metal::new_scoped_memory_pool();
if (heap_) {
heap_->release();
}
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -167,16 +202,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
size_t MetalAllocator::set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
size_t MetalAllocator::get_memory_limit() {
return block_limit_;
}
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_);
@@ -184,7 +222,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
Buffer MetalAllocator::malloc(size_t size) {
// Metal doesn't like empty buffers
if (size == 0) {
return Buffer{nullptr};
@@ -211,11 +249,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
if (!buf) {
size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size,
@@ -226,8 +259,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
if (num_resources_ >= resource_limit_) {
std::ostringstream msg;
msg << "[metal::malloc] Resource limit (" << resource_limit_
@@ -235,10 +266,19 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
throw std::runtime_error(msg.str());
}
lk.unlock();
buf = device_->newBuffer(size, res_opt);
if (size < small_size_ && heap_) {
buf = heap_->newBuffer(size, resource_options);
}
if (!buf) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
return Buffer{nullptr};
}
lk.lock();
if (buf) {
num_resources_++;
num_resources_++;
if (!buf->heap()) {
residency_set_.insert(buf);
}
}
@@ -246,14 +286,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
peak_memory_ = std::max(peak_memory_, active_memory_);
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
if (get_cache_memory() > max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
num_resources_ -= buffer_cache_.release_cached_buffers(
get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)};
}
@@ -269,12 +307,14 @@ void MetalAllocator::free(Buffer buffer) {
return;
}
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
num_resources_--;
if (!buf->heap()) {
residency_set_.erase(buf);
}
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
@@ -293,37 +333,40 @@ MetalAllocator& allocator() {
return *allocator_;
}
} // namespace metal
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
return metal::allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
size_t set_memory_limit(size_t limit) {
return metal::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return metal::allocator().get_memory_limit();
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
if (limit > std::get<size_t>(metal::device_info().at(
"max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
return metal::allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
return metal::allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
return metal::allocator().get_peak_memory();
}
void reset_peak_memory() {
allocator().reset_peak_memory();
metal::allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
return metal::allocator().get_cache_memory();
}
void clear_cache() {
return allocator().clear_cache();
return metal::allocator().clear_cache();
}
} // namespace metal
} // namespace mlx::core
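
Condensing the new allocation path: requests smaller than small_size_ are tried on the shared MTL::Heap first and fall back to a plain device buffer, and only non-heap buffers enter the residency set, since the heap as a whole is already resident. A sketch using only calls that appear in this diff (alloc_buffer is a hypothetical name):

MTL::Buffer* alloc_buffer(MTL::Device* device, MTL::Heap* heap, size_t size) {
  MTL::Buffer* buf = nullptr;
  if (size < 256 /* small_size_ */ && heap) {
    buf = heap->newBuffer(size, resource_options);  // nullptr once the heap is full
  }
  if (!buf) {
    buf = device->newBuffer(size, resource_options);
  }
  // The caller bumps num_resources_ and inserts buf into the residency set
  // only when !buf->heap(): the heap itself is already resident.
  return buf;
}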


@@ -18,7 +18,7 @@ namespace {
class BufferCache {
public:
BufferCache(MTL::Device* device);
BufferCache(ResidencySet& residency_set);
~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,12 +42,11 @@ class BufferCache {
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
ResidencySet& residency_set_;
};
} // namespace
@@ -55,7 +54,7 @@ class BufferCache {
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
@@ -72,13 +71,22 @@ class MetalAllocator : public allocator::Allocator {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_memory_limit(size_t limit);
size_t get_memory_limit();
size_t set_wired_limit(size_t limit);
void clear_cache();
private:
MTL::Device* device_;
// The size below which allocations go on the heap until it is full. This
// size is chosen because it is the actual minimum size of a buffer
// allocated from the heap; a heap can hold at most heap.size() / 256
// buffers.
static constexpr int small_size_ = 256;
static constexpr int heap_size_ = 1 << 20;
MTL::Heap* heap_;
MetalAllocator();
~MetalAllocator();
friend MetalAllocator& allocator();
// Caching allocator


@@ -102,16 +102,9 @@ void binary_op_gpu_inplace(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output.
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
int arg_idx = 0;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
compute_encoder.set_input_array(a, arg_idx++);
compute_encoder.set_input_array(b, arg_idx++);
compute_encoder.set_output_array(outputs[0], arg_idx++);
if (outputs.size() == 2) {
compute_encoder.set_output_array(outputs[1], arg_idx++);
@@ -164,8 +157,8 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace(inputs, outputs, op, s);
}
@@ -195,7 +188,7 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace(inputs, out, op, s);
}
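
The deleted donate_a/donate_b re-binding is redundant once set_binary_op_output_data has run: under the donation rules quoted in the removed comment, a donated input already shares its buffer with the output it was donated to, so binding a and b directly is equivalent. The presumed invariant, as a sketch (not verified against the allocator internals):

// donate_a               =>  outputs[0] aliases a's buffer
// donate_b && !donate_a  =>  outputs[0] aliases b's buffer
// donate_b && donate_a   =>  outputs[1] aliases b's buffer (two-output case)
// with a single output, at most one of a and b is donated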


@@ -457,7 +457,7 @@ void Compiled::eval_gpu(
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);
inputs, outputs, inputs_, constant_ids_, contiguous);
// Put the outputs in
for (auto& x : outputs) {


@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@@ -533,45 +533,6 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_fused_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
int O_c = conv_params.O;
int C_c = conv_params.C;
int N_tiles_n = conv_params.N;
int N_tiles_h = (conv_params.oS[0] + 1) / 2;
int N_tiles_w = (conv_params.oS[1] + 1) / 2;
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
int bc = 32;
int wm = 4;
int wn = 1;
std::ostringstream kname;
kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
<< conv_params.flip;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(8, 8, 2);
MTL::Size grid_dims =
MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -580,6 +541,67 @@ void winograd_conv_2D_gpu(
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
Shape padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
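// The two expressions above round the padded spatial extents so that
// (extent - 2) is a multiple of 6: the transform consumes 8x8 input windows
// that advance by 6x6 output tiles with a 2-pixel overlap, i.e.
// extent = 6 * ceil((x - 2) / 6) + 2. Worked example: x = 34 (a 32-pixel
// input with pad 1 on each side) -> 6 * ceil(32 / 6) + 2 = 6 * 6 + 2 = 38.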
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
fill_gpu(zero_arr, in_padded, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{
/* const int N = */ static_cast<int>(in_padded.shape(0)),
/* const int C = */ static_cast<int>(in_padded.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in_padded.shape(1)),
static_cast<int>(in_padded.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
/* const int idil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ false,
};
int O_c = conv_params.O;
int C_c = conv_params.C;
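Note on the padding arithmetic in the block above: 6 * ((x - 2 + 5) / 6) + 2 rounds the padded extent up so that x - 2 is a multiple of 6. Given the 8 * 8 transform shapes used below, this matches an F(6x6, 3x3) Winograd tiling: each 8x8 input tile yields a 6x6 output tile and overlaps its neighbor by 2, so the interior must tile in steps of 6. A small compilable check:

    #include <cstdio>

    // Round a padded spatial extent up so (extent - 2) is a multiple of 6.
    int round_for_winograd(int extent) {
      return 6 * ((extent - 2 + 5) / 6) + 2;
    }

    int main() {
      std::printf("%d -> %d\n", 33, round_for_winograd(33)); // 33 -> 38
      std::printf("%d -> %d\n", 26, round_for_winograd(26)); // 6k + 2 is unchanged
      return 0;
    }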
@@ -591,14 +613,14 @@ void winograd_conv_2D_gpu(
   // Do filter transform
   Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
   array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
-  filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
+  filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
     int bc = 32;
     int bo = 4;
     std::ostringstream kname;
     kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
-          << bc << "_flip" << conv_params.flip;
+          << bc;
     auto& compute_encoder = d.get_command_encoder(s.index);
     auto kernel = d.get_kernel(kname.str());
     compute_encoder.set_compute_pipeline_state(kernel);
@@ -618,7 +640,7 @@ void winograd_conv_2D_gpu(
   // Do input transform
   Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
   array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
-  inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
+  inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
     int bc = 32;
@@ -631,10 +653,10 @@
     auto kernel = d.get_kernel(kname.str());
     compute_encoder.set_compute_pipeline_state(kernel);
-    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_input_array(in_padded, 0);
     compute_encoder.set_output_array(inp_wg, 1);
-    compute_encoder.set_bytes(conv_params, 2);
+    compute_encoder.set_bytes(conv_params_updated, 2);
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -645,7 +667,7 @@ void winograd_conv_2D_gpu(
   // Do batched gemm
   Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
   array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
-  out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
+  out_wg.set_data(allocator::malloc(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
     std::vector<array> empty_copies;
@@ -681,7 +703,7 @@ void winograd_conv_2D_gpu(
     compute_encoder.set_input_array(out_wg, 0);
     compute_encoder.set_output_array(out, 1);
-    compute_encoder.set_bytes(conv_params, 2);
+    compute_encoder.set_bytes(conv_params_updated, 2);
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -690,6 +712,65 @@
   }
 }
+void depthwise_conv_2D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    const array& wt,
+    array out,
+    const MLXConvParams<2>& conv_params) {
+  std::ostringstream kname;
+  kname << "depthwise_conv_2d_" << type_to_name(out);
+  std::string base_name = kname.str();
+  const int N = conv_params.N;
+  const int ker_h = conv_params.wS[0];
+  const int ker_w = conv_params.wS[1];
+  const int str_h = conv_params.str[0];
+  const int str_w = conv_params.str[1];
+  const int tc = 8;
+  const int tw = 8;
+  const int th = 4;
+  const bool do_flip = conv_params.flip;
+  metal::MTLFCList func_consts = {
+      {&ker_h, MTL::DataType::DataTypeInt, 00},
+      {&ker_w, MTL::DataType::DataTypeInt, 01},
+      {&str_h, MTL::DataType::DataTypeInt, 10},
+      {&str_w, MTL::DataType::DataTypeInt, 11},
+      {&th, MTL::DataType::DataTypeInt, 100},
+      {&tw, MTL::DataType::DataTypeInt, 101},
+      {&do_flip, MTL::DataType::DataTypeBool, 200},
+  };
+  // clang-format off
+  kname << "_ker_h_" << ker_h
+        << "_ker_w_" << ker_w
+        << "_str_h_" << str_h
+        << "_str_w_" << str_w
+        << "_tgp_h_" << th
+        << "_tgp_w_" << tw
+        << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
+  std::string hash_name = kname.str();
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
+  compute_encoder.set_bytes(conv_params, 3);
+  MTL::Size group_dims = MTL::Size(tc, tw, th);
+  MTL::Size grid_dims = MTL::Size(
+      conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+}
 void conv_2D_gpu(
     const Stream& s,
     metal::Device& d,
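Note: the new depthwise_conv_2D_gpu bakes the kernel size, strides, and tile shape into the pipeline as Metal function constants (indices 00/01, 10/11, 100/101, 200) and keys the kernel cache on hash_name, so each specialization is compiled once. The threadgroup tiling is easy to sanity-check; a standalone sketch of the same arithmetic, with made-up sizes:

    #include <cstdio>

    int main() {
      // Each threadgroup covers tc = 8 channels, tw = 8 output columns,
      // and th = 4 output rows, mirroring the dispatch above.
      const int N = 2, C = 32, oH = 16, oW = 24;
      const int tc = 8, tw = 8, th = 4;
      // The C % 16 == 0 and oS % 8 == 0 guards in conv_2D_gpu make these
      // divisions exact.
      int gx = C / tc;        // threadgroups across channels
      int gy = oW / tw;       // across output width
      int gz = (oH / th) * N; // per-image rows, times the batch
      std::printf("grid: %d x %d x %d\n", gx, gy, gz); // 4 x 3 x 8
      return 0;
    }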
@@ -732,11 +813,20 @@ void conv_2D_gpu(
   bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
   bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
-  if (groups > 1) {
+  if (is_idil_one && groups > 1) {
     const int C_per_group = conv_params.C / groups;
     const int O_per_group = conv_params.O / groups;
-    if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+    if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+        conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+        conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+        conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+        conv_params.wt_strides[1] == conv_params.wS[1] &&
+        conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+      return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    }
+    if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
         (O_per_group <= 16 || O_per_group % 16 == 0)) {
       return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
     } else {
@@ -745,18 +835,14 @@
   }
   // Direct to winograd conv
-  bool img_large =
+  bool inp_large =
       (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
   bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
-      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
-      is_kdil_one && is_idil_one) {
-    if (img_large && channels_large) {
-      return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
-    }
-    if (conv_params.N <= 1) {
-      return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies);
-    }
+  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
+      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
+      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
+      channels_large) {
+    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+  }
   // Direct to implicit gemm conv
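Note: the rewritten gate folds the old nested checks into a single condition and adds !flip, which the deleted fused path apparently handled (its kernel name carried a _flip suffix). Worked through for one plausible shape:

    // N = 1, 64x64 input, C = O = 128, 3x3 kernel, stride 1, no dilation:
    //   inp_large:      1 * 64 * 64 = 4096 >= (1ul << 12) -> true
    //   channels_large: 128 + 128 = 256 >= 256            -> true
    // With !flip and the stride/dilation checks passing, Winograd is chosen.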
@@ -837,7 +923,7 @@ void conv_3D_gpu(
 } // namespace
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -858,40 +944,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
     wt = arr_copy;
   }
-  // Check for 1x1 conv
-  auto is_one = [](int x) { return x == 1; };
-  auto is_zero = [](int x) { return x == 0; };
-  if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) &&
-      std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) &&
-      std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) &&
-      std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) &&
-      std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) &&
-      std::all_of(padding_.begin(), padding_.end(), is_zero)) {
-    std::vector<array> empty_copies;
-    steel_matmul_regular(
-        s,
-        d,
-        /*a = */ in,
-        /*b = */ wt,
-        /*c = */ out,
-        /*M = */ in.size() / in.shape(-1),
-        /*N = */ wt.shape(0),
-        /*K = */ in.shape(-1),
-        /*batch_size_out = */ 1,
-        /*lda = */ in.shape(-1),
-        /*ldb = */ wt.shape(-1),
-        /*ldd = */ wt.shape(0),
-        /*transpose_a = */ false,
-        /*transpose_b = */ true,
-        /*batch_shape = */ {1},
-        /*batch_strides = */ {1},
-        /*A_batch_stride = */ 0,
-        /*B_batch_stride = */ 0,
-        /*matrix_stride_out = */ 0,
-        /*copies = */ empty_copies);
-  }
-  // 3D conv
-  else if (out.ndim() == 5) {
+  if (out.ndim() == 5) {
     conv_3D_gpu(
         s,
         d,
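Note: Convolution::eval_gpu no longer special-cases pointwise convolutions as a plain GEMM, presumably because the general conv_2D_gpu dispatch now covers that shape adequately. The equivalence the deleted branch exploited, as a naive reference (illustrative only):

    #include <cstddef>
    #include <vector>

    // A 1x1 convolution with unit strides, no dilation, and no padding is a
    // matmul of the flattened input [M, K] = [N*H*W, C_in] against the
    // transposed weights [Nw, K] = [C_out, C_in], exactly the
    // transpose_b = true call removed above.
    std::vector<float> conv1x1_as_gemm(
        const std::vector<float>& in,
        const std::vector<float>& wt,
        int M, int Nw, int K) {
      std::vector<float> out(static_cast<std::size_t>(M) * Nw, 0.0f);
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < Nw; ++n)
          for (int k = 0; k < K; ++k)
            out[m * Nw + n] += in[m * K + k] * wt[n * K + k];
      return out;
    }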


@@ -14,25 +14,11 @@ namespace mlx::core {
 constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
-  if (ctype == CopyType::Vector) {
-    // If the input is donateable, we are doing a vector copy and the types
-    // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
-      out.move_shared_buffer(in);
-      // If the output has the same type as the input then there is nothing to
-      // copy, just use the buffer.
-      if (in.dtype() == out.dtype()) {
-        return;
-      }
-    } else {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  bool donated = set_copy_output_data(in, out, ctype);
+  if (donated && in.dtype() == out.dtype()) {
+    // If the output has the same type as the input then there is nothing to
+    // copy, just use the buffer.
+    return;
   }
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
@@ -216,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
   std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
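Note: the branching that used to live inline in copy_gpu is now one call whose return value reports donation. Pieced together from the deleted lines, the helper plausibly looks like the sketch below; the real set_copy_output_data is defined elsewhere in this change set, reuses MLX-internal types (array, allocator, CopyType), and may differ in detail.

    // Sketch only: the deleted copy_gpu logic recast as a helper returning
    // whether the input buffer was donated to the output.
    bool set_copy_output_data(array& in, array& out, CopyType ctype) {
      if (ctype == CopyType::Vector) {
        if (in.is_donatable() && in.itemsize() == out.itemsize()) {
          out.move_shared_buffer(in);
          return true; // caller may skip the copy when dtypes also match
        }
        out.set_data(
            allocator::malloc(in.data_size() * out.itemsize()),
            in.data_size(),
            in.strides(),
            in.flags());
        return false;
      }
      out.set_data(allocator::malloc(out.nbytes()));
      return false;
    }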


@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
     copies.emplace_back(init_value_.value(), out.dtype());
     fill_gpu(copies.back(), out, s);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 }
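Note on the surrounding branch: when a CustomKernel carries an init_value_, the output is filled with it before launch; otherwise the buffer comes straight from malloc, uninitialized. Initialization matters for kernels that accumulate into their output rather than overwrite every element. A toy CPU analogue:

    #include <vector>

    // Accumulating kernels read-modify-write the output, so it must start
    // from a known value; 0.0f plays the role of init_value here.
    void histogram_kernel(std::vector<float>& out, const std::vector<int>& bins) {
      for (int b : bins) {
        out[b] += 1.0f; // garbage without prior initialization
      }
    }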


@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <cstdlib>
 #include <sstream>
@@ -19,9 +18,6 @@ namespace mlx::core::metal {
 namespace {
-// TODO nicer way to set this or possibly expose as an environment variable
-constexpr int MAX_BUFFERS_PER_QUEUE = 12;
 constexpr const char* default_mtllib_path = METAL_PATH;
 auto get_metal_version() {
@@ -168,9 +164,10 @@ void CommandEncoder::set_output_array(
   register_output_array(a);
 }
-void CommandEncoder::register_output_array(array& a) {
+void CommandEncoder::register_output_array(const array& a) {
   all_outputs_.insert(a.buffer().ptr());
-  auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
+  auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
   if (concurrent_) {
     concurrent_outputs_.insert(buf);
   } else {
@@ -249,15 +246,13 @@ Device::~Device() {
   for (auto& l : library_map_) {
     l.second->release();
   }
+  stream_map_.clear();
   device_->release();
 }
 void Device::new_queue(int index) {
   auto thread_pool = metal::new_scoped_memory_pool();
   // Multiple threads can ask the device for queues
   // We lock this as a critical section for safety
-  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
+  auto q = device_->newCommandQueue();
   debug_set_stream_queue_label(q, index);
   if (!q) {
     throw std::runtime_error(
@@ -269,6 +264,10 @@ void Device::new_queue(int index) {
   }
 }
+MTL::CommandQueue* Device::get_queue(Stream stream) {
+  return get_stream_(stream.index).queue;
+}
 bool Device::command_buffer_needs_commit(int index) {
   auto& stream = get_stream_(index);
   if (stream.buffer_ops > max_ops_per_buffer_ ||
@@ -690,12 +689,13 @@ void new_stream(Stream stream) {
   }
 }
-std::unordered_map<std::string, std::variant<std::string, size_t>>
+const std::unordered_map<std::string, std::variant<std::string, size_t>>&
 device_info() {
+  auto init_device_info = []()
+      -> std::unordered_map<std::string, std::variant<std::string, size_t>> {
     auto pool = new_scoped_memory_pool();
     auto raw_device = device(default_device()).mtl_device();
     auto name = std::string(raw_device->name()->utf8String());
+    auto arch = std::string(raw_device->architecture()->name()->utf8String());
     size_t memsize = 0;
@@ -709,6 +709,7 @@ device_info() {
   }
   return {
       {"device_name", name},
+      {"architecture", arch},
       {"max_buffer_length", raw_device->maxBufferLength()},
       {"max_recommended_working_set_size",

Some files were not shown because too many files have changed in this diff.