Compare commits

..

50 Commits

Author SHA1 Message Date
Awni Hannun
bf7cd29970 version bump (#698) 2024-02-16 08:44:08 -08:00
Nripesh Niketan
a000d2288c feat: update black pre-commit hook to 24.2.0 (#696) 2024-02-16 06:01:59 -08:00
Mike Drob
165abf0e4c Auto-run PRs from contributors (#692) 2024-02-15 17:30:35 -08:00
Srimukh Sripada
818cda16bc Support LR schedulers (#334)
* Add a few LR schedulers

* Move parents's constructor call to the top

* Fix docstring

* refactor optimizers into two files

* add docs

* nit

* Fix Callable type annotation for python 3.8

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-15 11:26:20 -08:00
toji
85143fecdd improved error msg for invalid axis(mx.split) (#685)
* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-02-15 07:25:38 -08:00
Diogo
35431a4ac8 Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
Awni Hannun
ccf1645995 Custom primitive + RoPE fat op (#676)
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 14:04:25 -08:00
Jagrit Digani
1a48713d32 Update gather and scatter to not use Argument Encoder (#683)
* Replace argument encoder usage for gather and scatter

* Use constant address space for shapes and strides

* Split gather and scatter to improve compile times

* Enable the GPU tests

* Update the CI config

* Fix scatter dispatch for scalar indices

* Remove arg encoder utils

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-14 13:42:13 -08:00
Awni Hannun
1eb04aa23f Fix empty array construction in cpp (#684) 2024-02-13 23:34:17 -08:00
Noah Farr
0c65517e91 Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0

* Add test case for repeats = 0
2024-02-13 17:49:31 -08:00
Vijay Krish
2fdc2462c3 Faster gather and scatter. (#682)
Reduce unnecessary integer ops, especially since
there kernels are integer bound.

Increase number of iterations for benchmarks for
better smoothing.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-13 17:47:41 -08:00
Hinrik Snær Guðmundsson
be6e9d6a9f Fixed wording in extensions.rst (#678)
changed "learn how add" -> "learn how to add"
2024-02-13 08:39:02 -08:00
Gabrijel Boduljak
e54cbb7ba6 Pooling layers (#357)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-12 22:08:13 -08:00
Angelos Katharopoulos
40c108766b Quantized matmul fix (#677)
* Fix qmv for small or unaligned matrices

* Fix qmm
2024-02-12 18:54:21 -08:00
Mike Drob
4cc70290f7 PR Builder Workflow (#659) 2024-02-12 17:47:21 -08:00
Awni Hannun
74caa68d02 nit in readme (#675) 2024-02-12 12:25:04 -08:00
Awni Hannun
3756381358 Faster bfloat quantized mat-vec and vec-mat (#663) 2024-02-11 21:53:16 -08:00
Awni Hannun
d12573daa6 quote file name (#670) 2024-02-11 10:33:30 -08:00
Nripesh Niketan
0dbc4c7547 feat: Update pre-commit-config.yaml (#667) 2024-02-11 06:08:20 -08:00
Vijay Krish
06072601ce Scatter optimization : Eliminate 64b integer divide. (#662)
Launch 2D grid to eliminate divide and mod in device code,
since 64b integer division is very expensive.

Github Issue #506

Co-authored-by: Vijay Krishnamoorthy <vijay_krish@apple.com>
2024-02-10 08:49:51 -08:00
Angelos Katharopoulos
11d2c8f7a1 Linux build for CI of other packages (#660) 2024-02-09 18:17:04 -08:00
Awni Hannun
7f3f8d8f8d Fix the softmax fix (#661) 2024-02-09 17:02:13 -08:00
Awni Hannun
b96be943dc bug fix (#658) 2024-02-09 16:50:45 -08:00
Abdussamet Türker
b670485185 Remainder negative numerator bug fixed (#641)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-09 16:49:14 -08:00
Diogo
b57bd0488d Metadata support for safetensors (#639)
* metadata support for safetensors

* aliases making it alittle more readable

* addressing comments

* python binding tests
2024-02-08 19:33:15 -08:00
Angelos Katharopoulos
221f8d3fc2 Bump the version to 0.2 (#656) 2024-02-08 11:27:12 -08:00
Awni Hannun
5c03efaf29 Compile docs (#653)
* compile docs

* docs nits + comments
2024-02-08 11:21:50 -08:00
LeonEricsson
7dccd42133 updated calls to use loc &scale (#643) 2024-02-08 09:01:59 -08:00
Awni Hannun
1b97b2958b Compile with capture (#629)
* Simple kernel generation

* Remove the generate kernel from graph_utils

* fix multi-output with compile

* fuse with stopgrad

* v1 input, output capture in compile

* cleanup tree update with visitor update

* nit

* remove todo

* state for model, optional explicit init and more pure optimizer steps

* move learning rate to state

* add lr to opt state, some fixes in capture

* fix optim

* update tuple of containers as well

* fix stream for compiled output

* rng state for compile

* nit

* updates and comments

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-02-07 17:29:22 -08:00
Awni Hannun
e5e816a5ef fix sequential with empty modules at end (#647) 2024-02-07 13:22:27 -08:00
Angelos Katharopoulos
28eac18571 Kernel generation (#614)
Generate reusable element-wise kernels given a computation graph.
2024-02-07 13:15:59 -08:00
Noah Farr
5fd11c347d Add loc and scale to random.normal (#638)
* Add loc and scale to random.normal

* Add tests for loc and scale for random.normal

* Run pre-commit hooks

* Fix code review
2024-02-07 11:49:59 -08:00
Aryan Gupta
ef73393a19 Feat: Add weights argument in BCE Loss and tests (#620) 2024-02-07 09:39:52 -08:00
Angelos Katharopoulos
ea406d5e33 CI change (#645)
* CI update

* Skip large binary test for now

* Upgrade pip

* Add proper env variable skipping

* Update the CI

* Fix workflow name

* Set the low memory flag for the tests

* Change build process

* Add pip upgrade

* Use a venv

* Add a missing env activate

* Add setuptools

* Add twine upload back

* Re-enable automatic release builds
2024-02-07 06:04:34 -08:00
Awni Hannun
146bd69470 Skip compile when transforming (#635)
* skip compile when transforming

* simplify message
2024-02-05 21:28:37 -08:00
Jagrit Digani
316ff490b3 Remove masks from BlockLoader and clear out load case for invalid thread (#634) 2024-02-05 16:00:17 -08:00
Awni Hannun
d40a04f8dc minor fixes (#631)
* minor fixes

* var with ddof >= nelements
2024-02-05 13:27:49 -08:00
Awni Hannun
d75ae52ecd Compile primitive (#571)
* Compiled primitive with basic binary, unary graph-level fusion
2024-02-05 06:51:22 -08:00
Avikant Srivastava
31fea3758e feat: enhancement of the error message for mlx.core.mean (#608)
* add error message
2024-02-05 01:21:49 -08:00
Awni Hannun
e319383ef9 Faster gather (#626)
* faster gather

* update copyright
2024-02-04 17:25:44 -08:00
Awni Hannun
5c3ac52dd7 fix test (#627) 2024-02-04 16:18:03 -08:00
David Koski
ebfd3618b0 fixes for building and running on iOS (#619)
* fixes for building and running on iOS

* per suggestion just use Accelerate
2024-02-04 12:29:17 -08:00
Avikant Srivastava
11a9fd40f0 fix: handle linspace function when num is 1 (#602)
* fix: handle linspace function when num is 1

* add comment

* fix test case

* remove breakpoint
2024-02-04 11:03:49 -08:00
Daniel Strobusch
4fd2fb84a6 make python array SupportsAbs conform (like numpy) (#624) 2024-02-04 09:31:02 -08:00
Daniel Strobusch
9852af1a19 fix "shape" docstring. (#623) 2024-02-04 09:21:22 -08:00
minghuaw
16750f3c51 Fix typo in CMakeLists.txt (#616) 2024-02-03 05:59:26 -08:00
Awni Hannun
95b5fb8245 minor changes (#613) 2024-02-02 11:48:35 -08:00
AtomicVar
83f63f2184 Add Margin Ranking Loss (#536) 2024-02-02 10:57:31 -08:00
Awni Hannun
cb6156d35d Fix eval in trace bugs (#612)
* Fix eval in trace bugs

* comment nit
2024-02-02 09:57:12 -08:00
Piotr Rybiec
506d43035c typo fix (#607) 2024-02-01 17:39:55 -08:00
134 changed files with 6715 additions and 1721 deletions

View File

@@ -1,5 +1,8 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
@@ -7,6 +10,9 @@ parameters:
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
linux_build_and_test:
@@ -57,17 +63,18 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
machine: true
resource_class: ml-explore/m-builder
macos:
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=3.9
conda activate runner-env
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
@@ -78,203 +85,158 @@ jobs:
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
source env/bin/activate
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# eval "$(conda shell.bash hook)"
# conda activate runner-env
# cd examples/extensions && python -m pip install .
# cd examples/extensions && python3.11 -m pip install .
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
build_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
xcode_version:
type: string
default: "14"
default: "15.2.0"
build_env:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
brew install python@<< parameters.python_version >>
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
# TODO: Update build system to switch away from setup.py develop
pip install build
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py develop
pip install . -v
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
source env/bin/activate
python setup.py generate_stubs
- run:
name: Publish Python package
name: Build Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
twine upload dist/* --repository mlx
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
- store_artifacts:
path: dist/
build_dev_release:
machine: true
resource_class: ml-explore/m-builder
build_linux_test_release:
parameters:
python_version:
type: string
default: "3.9"
macos_version:
extra_env:
type: string
default: "14"
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
steps:
- checkout
- run:
name: Install dependencies
name: Build wheel
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
pip install auditwheel
pip install patchelf
pip install build
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py develop
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
pip install . -v
python setup.py generate_stubs
- run:
name: Publish Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
twine upload dist/* --repository mlx
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- store_artifacts:
path: dist/
build_package:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
type: string
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install twine
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py develop
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
python setup.py generate_stubs
- run:
name: Build package distribution
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
- store_artifacts:
path: dist/
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- linux_build_and_test
- mac_build_and_test
- linux_build_and_test
- build_release:
filters:
tags:
@@ -284,20 +246,53 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
when: << pipeline.parameters.nightly_build >>
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_package:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
xcode_version: ["14.3.1", "15.2.0"]
weekly_build:
when: << pipeline.parameters.weekly_build >>
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
jobs:
- build_dev_release:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
macos_version: ["13", "14"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -5,7 +5,7 @@ repos:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -10,8 +10,8 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.1.0)
set(MLX_VERSION 0.3.0)
endif()
# --------------------- Processor tests -------------------------
@@ -123,8 +123,8 @@ else()
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib" ${BLAS_LIBRARIES})
message(STATUS "Blas incclude" ${BLAS_INCLUDE_DIRS})
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
@@ -134,7 +134,7 @@ else()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib" ${LAPACK_LIBRARIES})
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})

View File

@@ -1,4 +1,4 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
python/mlx/py.typed # support type hinting as in PEP-561
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -6,8 +6,8 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:

View File

@@ -80,10 +80,8 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
if _filter(x)
else None
)

View File

@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -0,0 +1,56 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
def scatter(dst, x, idx):
dst[idx] = x
mx.eval(dst)
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
def gather(dst, x, idx, device):
dst[idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import time
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 100
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

1
docs/.gitignore vendored
View File

@@ -1,2 +1,3 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

View File

@@ -1,19 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

View File

@@ -26,6 +26,7 @@ extensions = [
python_use_unqualified_type_names = True
autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"https://docs.python.org/3": None,

View File

@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how add
our assumptions on to you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++
@@ -927,18 +927,18 @@ Results:
We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations like :meth:`grad`!
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
The full example code is available in `mlx <code>`_.
.. code: `TODO_LINK/extensions`_
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc

View File

@@ -41,6 +41,7 @@ are the CPU and GPU.
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/using_streams

View File

@@ -9,9 +9,10 @@ Devices and Streams
:toctree: _autosummary
Device
Stream
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
stream

View File

@@ -10,6 +10,8 @@ Layers
:template: nn-module-template.rst
ALiBi
AvgPool1d
AvgPool2d
BatchNorm
Conv1d
Conv2d
@@ -22,6 +24,8 @@ Layers
InstanceNorm
LayerNorm
Linear
MaxPool1d
MaxPool2d
Mish
MultiHeadAttention
PReLU

View File

@@ -18,6 +18,7 @@ Loss Functions
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss

View File

@@ -11,6 +11,7 @@ Module
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods

View File

@@ -29,20 +29,8 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. currentmodule:: mlx.optimizers
.. toctree::
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers

View File

@@ -0,0 +1,20 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -0,0 +1,23 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@@ -0,0 +1,13 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
step_decay
exponential_decay
cosine_decay

View File

@@ -9,6 +9,9 @@ Transforms
:toctree: _autosummary
eval
compile
disable_compile
enable_compile
grad
value_and_grad
jvp

430
docs/src/usage/compile.rst Normal file
View File

@@ -0,0 +1,430 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@@ -5,9 +5,12 @@ Function Transforms
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.
MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations check-out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
Here is a simple example:
@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. To see the complete list of function transformations check-out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiaion <auto diff>` and
:ref:`automatic vectorization <vmap>`.
depth. See the following sections for more information on :ref:`automatic
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
Automatic Differentiation
-------------------------

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.24)
cmake_minimum_required(VERSION 3.27)
project(mlx_sample_extensions LANGUAGES CXX)
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
endif()

View File

@@ -3,9 +3,10 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp

View File

@@ -82,6 +82,13 @@ array::array(std::initializer_list<float> data)
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
@@ -180,7 +187,7 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
@@ -197,7 +204,7 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}

View File

@@ -41,6 +41,9 @@ class array {
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T>
array(
std::initializer_list<T> data,
@@ -121,6 +124,9 @@ class array {
template <typename T>
T item();
template <typename T>
T item() const;
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t;
@@ -454,6 +460,18 @@ T array::item() {
return *data<T>();
}
template <typename T>
T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (!is_evaled()) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
return *data<T>();
}
template <typename It>
void array::init(It src) {
set_data(allocator::malloc(size() * size_of(dtype())));

View File

@@ -46,6 +46,9 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
@@ -94,6 +97,9 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;

View File

@@ -33,10 +33,12 @@ DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@@ -57,8 +59,10 @@ DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Sigmoid)
@@ -68,8 +72,6 @@ DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -291,45 +293,6 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides()[x.ndim() - 1] == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -3,6 +3,7 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
@@ -10,6 +11,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp

View File

@@ -140,16 +140,34 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
return numerator % denominator;
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator;
}
};

View File

@@ -0,0 +1,59 @@
// Copyright © 2023-2024 Apple Inc.
#include <queue>
#include "mlx/primitives.h"
namespace mlx::core {
// Build the real tape
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
const std::vector<array>& trace_tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
std::queue<array> tape;
for (auto& a : trace_tape) {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
tape.push(
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
trace_to_real.insert({a.id(), tape.back()});
}
std::vector<array> outputs;
for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id()));
}
return {tape, outputs};
}
void Compiled::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the a real tape from the tracers
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
// Run the tape
while (!tape.empty()) {
auto a = std::move(tape.front());
tape.pop();
auto outputs = a.outputs();
a.primitive().eval_cpu(a.inputs(), outputs);
a.detach();
}
// Copy results into outputs
for (int o = 0; o < real_outputs.size(); ++o) {
outputs[o].copy_shared_buffer(real_outputs[o]);
}
}
} // namespace mlx::core

View File

@@ -3,7 +3,7 @@
#include <cassert>
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
@@ -41,7 +41,9 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
@@ -78,6 +80,7 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
@@ -100,8 +103,6 @@ DEFAULT(Subtract)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)
namespace {
@@ -131,7 +132,9 @@ inline void matmul_common_general(
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;

View File

@@ -5,7 +5,7 @@
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/lapack.h>
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

View File

@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides().back() == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -1,7 +1,28 @@
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${CMAKE_SOURCE_DIR}
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
)
add_custom_target(
compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx compiled_preamble)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
@@ -11,10 +32,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -0,0 +1,484 @@
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}
inline auto get_type_string(Dtype d) {
switch (d) {
case float32:
return "float";
case float16:
return "half";
case bfloat16:
return "bfloat16_t";
case bool_:
return "bool";
case int8:
return "int8_t";
case int16:
return "int16_t";
case int32:
return "int32_t";
case int64:
return "int64_t";
case uint8:
return "uint8_t";
case uint16:
return "uint16_t";
case uint32:
return "uint32_t";
case uint64:
return "uint64_t";
default: {
std::ostringstream msg;
msg << "Unsupported compilation type " << d;
throw std::runtime_error(msg.str());
}
}
}
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
os << x.item<T>();
}
void print_constant(std::ostream& os, const array& x) {
switch (x.dtype()) {
case float32:
return print_float_constant<float>(os, x);
case float16:
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case int8:
return print_int_constant<int8_t>(os, x);
case int16:
return print_int_constant<int16_t>(os, x);
case int32:
return print_int_constant<int32_t>(os, x);
case int64:
return print_int_constant<int64_t>(os, x);
case uint8:
return print_int_constant<uint8_t>(os, x);
case uint16:
return print_int_constant<uint16_t>(os, x);
case uint32:
return print_int_constant<uint32_t>(os, x);
case uint64:
return print_int_constant<uint64_t>(os, x);
case bool_:
os << std::boolalpha << x.item<bool>();
return;
default:
throw std::runtime_error("Unsupported constant type");
}
}
inline std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
std::ostringstream os;
std::ostringstream constant_hasher;
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
a.primitive().print(os);
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << ((x.size() == 1) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim,
bool dynamic_dims) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
// For scalar we shouldn't do the indexing things, just read at 0
auto is_scalar = [](const array& x) { return x.size() == 1; };
NodeNamer namer;
bool add_indices = false;
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
// Add the input arguments
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
<< std::endl;
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
<< std::endl;
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
<< std::endl;
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
if (add_indices) {
if (!dynamic_dims) {
if (ndim == 1) {
os << " uint index_0 = pos.x;" << std::endl;
} else if (ndim == 2) {
os << " uint index_0 = pos.y;" << std::endl
<< " uint index_1 = pos.x;" << std::endl;
} else if (ndim == 3) {
os << " uint index_0 = pos.z;" << std::endl
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << xname << "_strides[0]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
}
os << "];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
// Finish the kernel
os << "}" << std::endl;
if (cnt > 31) {
std::ostringstream msg;
msg << "[compile] Too many inputs/outputs fused in the Metal Compile "
<< "primitive which exhausted the available argument buffers for "
<< "the kernel. Please file an issue with the function that results "
<< "in this error. The name of the kernel is '" << kernel_name << "'";
throw std::runtime_error(msg.str());
}
}
void Compiled::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the name for the kernel library
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_);
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::get_kernel_preamble() << std::endl;
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i),
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false);
}
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true);
kernel_source_ = kernel.str();
lib = d.get_library(kernel_lib_, kernel_source_);
}
// Allocate space for the outputs
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
x.size() > 1) {
contiguous = false;
break;
}
}
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<std::vector<size_t>> initial_strides;
initial_strides.push_back(outputs[0].strides());
std::vector<int> shape;
std::vector<std::vector<size_t>> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
// Skip scalar inputs.
if (x.size() <= 1) {
continue;
}
// Broadcast the inputs to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
initial_strides.push_back(std::move(xstrides));
}
std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides);
}
// Get the kernel from the lib
int ndim = shape.size();
bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
if (dynamic) {
kernel_name += "dynamic";
} else {
kernel_name += std::to_string(shape.size());
}
}
auto kernel = d.get_kernel(kernel_name, lib);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && x.size() > 1) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
stride_idx++;
}
}
// Put the outputs in
for (auto& x : outputs) {
set_array_buffer(compute_encoder, x, cnt++);
}
// Put the output shape and strides in
if (!contiguous) {
compute_encoder->setBytes(
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
}
// Put the number of dims in if it is dynamic
if (dynamic) {
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
}
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].size();
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,9 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* get_kernel_preamble();
}

View File

@@ -26,7 +26,8 @@ static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0));
auto device = static_cast<MTL::Device*>(devices->object(0))
?: MTL::CreateSystemDefaultDevice();
if (!device) {
throw std::runtime_error("Failed to load device");
}
@@ -214,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
return eit->second;
}
MTL::ArgumentEncoder* Device::argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
// NB array here is already autoreleased but the returned argument
// encoder is owned by the caller and must be released/autoreleased
NS::Array* arg_desc_arr = NS::Array::array(
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
return device_->newArgumentEncoder(arg_desc_arr);
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
@@ -413,6 +405,11 @@ MTL::ComputePipelineState* Device::get_kernel_(
return kernel;
}
MTL::Library* Device::get_library(const std::string& name) {
auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,

View File

@@ -62,6 +62,8 @@ class Device {
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::Library* get_library(const std::string& name);
MTL::Library* get_library(
const std::string& name,
const std::string& source_string,

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@@ -39,114 +39,75 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
}
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
size_t ndim = src.ndim();
size_t nthreads = out.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
// Launch 2D grid of threads: indices x slice
size_t dim0 = out.size() / slice_size;
size_t dim1 = slice_size;
auto group_dims = get_block_dims(dim0, dim1, 1);
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy(
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
inputs[i + 1].strides().end());
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
if (idx_ndim > 0) {
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
MTL::ResourceUsageRead);
}
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
// Set all the buffers
set_array_buffer(compute_encoder, src, 0);
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
set_array_buffer(compute_encoder, out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -211,82 +172,35 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
// Collect all idx shapes and strides into one place
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
std::copy(
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
inputs[i + 1].strides().end());
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Set all the buffers
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
if (idx_ndim > 0) {
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
MTL::ResourceUsageRead);
}
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
// Set update info
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
@@ -301,6 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
@@ -316,16 +231,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -6,6 +6,7 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@@ -22,11 +23,13 @@ set(
"quantized"
"random"
"reduce"
"rope"
"scan"
"softmax"
"sort"
"unary"
"indexing"
"gather"
"scatter"
)
function(build_kernel_base TARGET SRCFILE DEPS)

View File

@@ -0,0 +1,231 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};

View File

@@ -1,176 +1,6 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Add {
template <typename T> T operator()(T x, T y) { return x + y; }
};
struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
struct NaNEqual {
template <typename T> bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real)
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T> bool operator()(T x, T y) { return x > y; }
};
struct GreaterEqual {
template <typename T> bool operator()(T x, T y) { return x >= y; }
};
struct Less {
template <typename T> bool operator()(T x, T y) { return x < y; }
};
struct LessEqual {
template <typename T> bool operator()(T x, T y) { return x <= y; }
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf) ? maxval :
(maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T> T operator()(T x, T y) { return x * y; }
};
struct NotEqual {
template <typename T> bool operator()(T x, T y) { return x != y; }
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T> T operator()(T x, T y) { return x - y; }
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) { return x && y; };
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) { return x || y; };
};
#include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(

View File

@@ -14,10 +14,29 @@ struct FloorDivide {
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <> complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
template <typename T, typename U, typename Op1, typename Op2>

View File

@@ -0,0 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/unary.h"

View File

@@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
if (real != 0 && (real < 0 != b.real < 0)) {
real += b.real;
}
if (imag != 0 && (imag < 0 != b.imag < 0)) {
imag += b.imag;
}
return {real, imag};
}

View File

@@ -0,0 +1,187 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T *src [[buffer(0)]],
device T *out [[buffer(1)]],
const constant int *src_shape [[buffer(2)]],
const constant size_t *src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int *slice_sizes [[buffer(5)]],
const constant int *axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T *src [[buffer(0)]], \
device T *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
const device src_t *src [[buffer(0)]], \
device src_t *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,54 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const int ndim;
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {
return (idx < 0) ? idx + size : idx;
}
}
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
#define IDX_ARR_N(n) idx##n,
#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)

View File

@@ -1,254 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_texture>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<device IdxT*, NIDX> buffers [[id(0)]];
device int* shapes [[id(NIDX + 1)]];
device size_t* strides [[id(NIDX + 2)]];
const int ndim [[id(NIDX + 3)]];
};
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
const device T *src [[buffer(0)]],
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
device T *out [[buffer(2)]],
const device int *src_shape [[buffer(3)]],
const device size_t *src_strides [[buffer(4)]],
const device size_t& src_ndim [[buffer(5)]],
const device int *slice_sizes [[buffer(6)]],
const device size_t& slice_size [[buffer(7)]],
const device int *axes [[buffer(8)]],
uint gid [[thread_position_in_grid]]) {
auto ind_idx = gid / slice_size;
auto ind_offset = gid % slice_size;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
out[gid] = src[src_idx + src_offset];
}
#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
const device src_type *src [[buffer(0)]], \
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const device int *src_shape [[buffer(3)]], \
const device size_t *src_strides [[buffer(4)]], \
const device size_t& src_ndim [[buffer(5)]], \
const device int *slice_sizes [[buffer(6)]], \
const device size_t& slice_size [[buffer(7)]], \
const device int* axes [[buffer(8)]], \
uint gid [[thread_position_in_grid]]);
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const device int *upd_shape [[buffer(3)]],
const device size_t *upd_strides [[buffer(4)]],
const device size_t& upd_ndim [[buffer(5)]],
const device size_t& upd_size [[buffer(6)]],
const device int *out_shape [[buffer(7)]],
const device size_t *out_strides [[buffer(8)]],
const device size_t& out_ndim [[buffer(9)]],
const device int* axes [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid / upd_size;
auto ind_offset = gid % upd_size;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
const device type *updates [[buffer(1)]], \
device mlx_atomic<type> *out [[buffer(2)]], \
const device int *upd_shape [[buffer(3)]], \
const device size_t *upd_strides [[buffer(4)]], \
const device size_t& upd_ndim [[buffer(5)]], \
const device size_t& upd_size [[buffer(6)]], \
const device int *out_shape [[buffer(7)]], \
const device size_t *out_strides [[buffer(8)]], \
const device size_t& out_ndim [[buffer(9)]], \
const device int* axes [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)

View File

@@ -15,6 +15,14 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T> struct AccT {
typedef T acc_t;
};
template <> struct AccT<bfloat16_t> {
typedef float acc_t;
};
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
@@ -31,21 +39,23 @@ template <typename T, const int BM, const int BN, const int group_size, const in
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
(void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_thread = 32 / bits;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / group_size;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[colgroup];
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[colgroup];
thread uint32_t w_local;
thread T result = 0;
thread T scale = 1;
thread T bias = 0;
thread T x_thread[el_per_thread];
thread U result = 0;
thread U scale = 1;
thread U bias = 0;
thread U x_thread[el_per_thread];
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
@@ -57,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
if (out_row >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid < simdgroups_fetching_vec) {
x_block[lid] = x[lid + i];
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<el_per_thread; j++) {
x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
}
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
@@ -90,7 +107,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= bits;
}
}
@@ -100,7 +117,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Store the result
if (simd_lid == 0) {
y[out_row] = result;
y[out_row] = static_cast<T>(result);
}
}
@@ -129,15 +146,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size;
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[BM];
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[BM];
thread uint32_t w_local;
thread T result[el_per_int] = {0};
thread T scale = 1;
thread T bias = 0;
thread T x_local = 0;
thread U result[el_per_int] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
@@ -186,7 +204,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
w_local >>= bits;
}
}
@@ -201,7 +219,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
y[k] = result[k];
y[k] = static_cast<T>(result[k]);
}
}
}
@@ -240,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
@@ -303,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
if (y_col + offset_col < N) {
if (y_row + offset_row < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -418,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
if (num_els < BM) {
loader_x.load_safe(short2(BK, num_els));
short num_k = min(BK, K - k);
if (num_els < BM || num_k < BK) {
loader_x.load_safe(short2(num_k, num_els));
} else {
loader_x.load_unsafe();
}
@@ -447,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
if (k + BK >= K) {
if (num_k < BK) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);

View File

@@ -0,0 +1,68 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
constant const size_t strides[3],
constant const int& offset,
constant const float& base,
constant const float& scale,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
// Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1 = x1 * costheta - x2 * sintheta;
float rx2 = x1 * sintheta + x2 * costheta;
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
}
#define instantiate_rope(name, type, traditional) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)

View File

@@ -0,0 +1,194 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int *upd_shape [[buffer(3)]],
const constant size_t *upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int *out_shape [[buffer(7)]],
const constant size_t *out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] \
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)

View File

@@ -89,20 +89,9 @@ struct GEMMKernel {
// Appease the compiler
(void)l;
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
if (!M_aligned) {
short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_a.set_mask(tile_dims_A, mask_A);
}
if (!N_aligned) {
short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
loader_b.set_mask(tile_dims_B, mask_B);
}
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -110,13 +99,13 @@ struct GEMMKernel {
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(mask_A);
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(mask_B);
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -137,11 +126,8 @@ struct GEMMKernel {
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.set_mask(tile_dims_A_last, mask_A);
loader_b.set_mask(tile_dims_B_last, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -218,14 +204,8 @@ struct GEMMKernel {
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
loader_a.set_mask(tile_dims_A, mask_A);
loader_b.set_mask(tile_dims_B, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@@ -112,14 +112,8 @@ template <typename T,
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
loader_a.set_mask(tile_dims_A, mask_A);
loader_b.set_mask(tile_dims_B, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);

View File

@@ -67,24 +67,22 @@ struct BlockLoader {
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void set_mask(
thread const short2& src_tile_dims,
thread bool mask[n_rows][vec_size]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
mask[i][j] =
((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
@@ -117,39 +115,6 @@ struct BlockLoader {
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
simdgroup_barrier(mem_flags::mem_none);
// Use fast thread memory for bound checks
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
}
simdgroup_barrier(mem_flags::mem_none);
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
}
simdgroup_barrier(mem_flags::mem_none);
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;

View File

@@ -0,0 +1,376 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/utils.h"
struct Abs {
template <typename T>
T operator()(T x) {
return metal::abs(x);
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return metal::precise::acos(x);
};
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return metal::precise::acosh(x);
};
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return metal::precise::asin(x);
};
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return metal::precise::asinh(x);
};
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return metal::precise::atan(x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return metal::precise::atanh(x);
};
};
struct Ceil {
template <typename T>
T operator()(T x) {
return metal::ceil(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Cos {
template <typename T>
T operator()(T x) {
return metal::precise::cos(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Cosh {
template <typename T>
T operator()(T x) {
return metal::precise::cosh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(erf(static_cast<float>(x)));
};
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(erfinv(static_cast<float>(x)));
};
};
struct Exp {
template <typename T>
T operator()(T x) {
return metal::precise::exp(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return metal::floor(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Log {
template <typename T>
T operator()(T x) {
return metal::precise::log(x);
};
};
struct Log2 {
template <typename T>
T operator()(T x) {
return metal::precise::log2(x);
};
};
struct Log10 {
template <typename T>
T operator()(T x) {
return metal::precise::log10(x);
};
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
};
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
};
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
};
};
struct Round {
template <typename T>
T operator()(T x) {
return metal::rint(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)};
};
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
};
template <>
uint32_t operator()(uint32_t x) {
return x != 0;
};
};
struct Sin {
template <typename T>
T operator()(T x) {
return metal::precise::sin(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Sinh {
template <typename T>
T operator()(T x) {
return metal::precise::sinh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
};
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return metal::precise::sqrt(x);
};
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return metal::precise::rsqrt(x);
};
};
struct Tan {
template <typename T>
T operator()(T x) {
return metal::precise::tan(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
};
};
struct Tanh {
template <typename T>
T operator()(T x) {
return metal::precise::tanh(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
};
};

View File

@@ -1,223 +1,6 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Abs {
template <typename T> T operator()(T x) { return metal::abs(x); };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
template <> complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T> T operator()(T x) { return metal::precise::acos(x); };
};
struct ArcCosh {
template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
};
struct ArcSin {
template <typename T> T operator()(T x) { return metal::precise::asin(x); };
};
struct ArcSinh {
template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
};
struct ArcTan {
template <typename T> T operator()(T x) { return metal::precise::atan(x); };
};
struct ArcTanh {
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};
struct Ceil {
template <typename T> T operator()(T x) { return metal::ceil(x); };
template <> int8_t operator()(int8_t x) { return x; };
template <> int16_t operator()(int16_t x) { return x; };
template <> int32_t operator()(int32_t x) { return x; };
template <> int64_t operator()(int64_t x) { return x; };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
};
struct Cos {
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Cosh {
template <typename T> T operator()(T x) { return metal::precise::cosh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Erf {
template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
};
struct ErfInv {
template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
};
struct Exp {
template <typename T> T operator()(T x) { return metal::precise::exp(x); };
template <> complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Floor {
template <typename T> T operator()(T x) { return metal::floor(x); };
template <> int8_t operator()(int8_t x) { return x; };
template <> int16_t operator()(int16_t x) { return x; };
template <> int32_t operator()(int32_t x) { return x; };
template <> int64_t operator()(int64_t x) { return x; };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
};
struct Log {
template <typename T> T operator()(T x) { return metal::precise::log(x); };
};
struct Log2 {
template <typename T> T operator()(T x) { return metal::precise::log2(x); };
};
struct Log10 {
template <typename T> T operator()(T x) { return metal::precise::log10(x); };
};
struct Log1p {
template <typename T> T operator()(T x) { return log1p(x); };
};
struct LogicalNot {
template <typename T> T operator()(T x) { return !x; };
};
struct Negative {
template <typename T> T operator()(T x) { return -x; };
};
struct Round {
template <typename T> T operator()(T x) { return metal::rint(x); };
template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
template <> uint32_t operator()(uint32_t x) { return x != 0; };
};
struct Sin {
template <typename T> T operator()(T x) { return metal::precise::sin(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Sinh {
template <typename T> T operator()(T x) { return metal::precise::sinh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Square {
template <typename T> T operator()(T x) { return x * x; };
};
struct Sqrt {
template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
};
struct Rsqrt {
template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
};
struct Tan {
template <typename T> T operator()(T x) { return metal::precise::tan(x); };
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {
(tan_a - tanh_b * t1) / denom,
(tanh_b + tan_a * t1) / denom
};
};
};
struct Tanh {
template <typename T> T operator()(T x) { return metal::precise::tanh(x); };
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {
(tanh_a + tan_b * t1) / denom,
(tan_b - tanh_a * t1) / denom
};
};
};
#include "mlx/backend/metal/kernels/unary.h"
template <typename T, typename Op>
[[kernel]] void unary_op_v(

View File

@@ -12,10 +12,10 @@
template <typename U>
struct Limits {
static const constant U max;
static const constant U min;
static const constant U finite_max;
static const constant U finite_min;
static const constant U max = metal::numeric_limits<U>::max();
static const constant U min = metal::numeric_limits<U>::min();
static const constant U finite_max = metal::numeric_limits<U>::max();
static const constant U finite_min = metal::numeric_limits<U>::min();
};
#define instantiate_default_limit(type) \
@@ -71,7 +71,7 @@ inline size_t elem_to_loc(
device const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
@@ -84,7 +84,7 @@ inline size_t elem_to_loc(
constant const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
@@ -273,4 +273,4 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
}

View File

@@ -0,0 +1,28 @@
#!/bin/bash
#
# This script generates a C++ function that provides the Metal unary and binary
# ops at runtime for use with kernel generation.
#
# Copyright © 2023-24 Apple Inc.
OUTPUT_FILE=$1
CC=$2
SRCDIR=$3
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
// Copyright © 2023-24 Apple Inc.
namespace mlx::core::metal {
const char* get_kernel_preamble() {
return R"preamble(
$CONTENT
)preamble";
}
} // namespace mlx::core::metal
EOF

View File

@@ -63,7 +63,15 @@ std::function<void()> make_task(
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs);
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());

View File

@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);

View File

@@ -0,0 +1,55 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"
namespace mlx::core::fast {
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
if (in.ndim() != 3) {
throw std::runtime_error(
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
}
if (dims_ != in.shape(-1)) {
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
if (in.flags().row_contiguous && in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
bool donated = in.data_shared_ptr() == nullptr;
float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&offset_, sizeof(int), 3);
compute_encoder->setBytes(&base, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 5);
int dim0 = in.shape(2) / 2;
int dim1 = in.shape(1);
int dim2 = in.shape(0);
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core::fast

View File

@@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
if (x.strides()[x.ndim() - 1] == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -9,20 +9,6 @@ namespace mlx::core {
namespace {
void set_array_buffer(
MTL::ComputeCommandEncoder* compute_encoder,
MTL::ArgumentEncoder* enc,
const array& a,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
// MTL::Resource usage through argument buffer needs to be explicitly
// flagged to enable hazard tracking
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}
void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
@@ -117,16 +103,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (xs[0].ndim() > 0) {
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < xs[0].ndim(); i++) {
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (auto& x : xs) {
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
for (const std::vector<size_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
@@ -142,21 +130,31 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(xs.size());
std::vector<std::vector<size_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = xs[0].shape()[to_collapse[i]];
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= xs[0].shape()[to_collapse[i]];
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < xs.size(); j++) {
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
for (int j = 0; j < strides.size(); j++) {
const std::vector<size_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/fast.h"
#define NO_GPU_MULTI(func) \
void func::eval_gpu( \
@@ -32,6 +33,7 @@ NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
@@ -40,6 +42,7 @@ NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
@@ -69,6 +72,7 @@ NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Reduce)
@@ -91,6 +95,9 @@ NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU_MULTI(DivMod)
NO_GPU_MULTI(QRF)
namespace fast {
NO_GPU_MULTI(RoPE)
} // namespace fast
} // namespace mlx::core

View File

@@ -1,36 +1,168 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/compile.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
namespace detail {
constexpr int max_compile_depth = 10;
bool& compiler_disabled() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return true;
} else {
return false;
}
};
static bool compiler_disabled_ = get_val();
return compiler_disabled_;
bool is_unary(const Primitive& p) {
return (
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
typeid(p) == typeid(Tanh));
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
bool is_binary(const Primitive& p) {
return (
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
typeid(p) == typeid(Subtract));
}
bool is_broadcast(const Primitive& p) {
return typeid(p) == typeid(Broadcast);
}
bool is_noop(const Primitive& p) {
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
}
namespace detail {
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs);
} // namespace detail
Compiled::Compiled(
Stream stream,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape,
std::unordered_set<uintptr_t> constant_ids)
: Primitive(stream),
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
tape_(std::move(tape)),
constant_ids_(std::move(constant_ids)) {}
std::vector<array> Compiled::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
}
std::vector<array> Compiled::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
}
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
}
bool Compiled::is_equivalent(const Primitive& other) const {
const Compiled& a_other = static_cast<const Compiled&>(other);
return std::equal(
tape_.begin(),
tape_.end(),
a_other.tape_.begin(),
a_other.tape_.end(),
[](const array& a1, const array& a2) {
auto& p1 = a1.primitive();
auto& p2 = a2.primitive();
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
});
}
void Compiled::print(std::ostream& os) {
os << "Compiled";
for (auto& a : tape_) {
a.primitive().print(os);
}
}
namespace detail {
CompileMode& compile_mode() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return CompileMode::disabled;
} else {
return CompileMode::enabled;
}
};
static CompileMode compile_mode_ = get_val();
return compile_mode_;
}
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
// Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination
void merge(array& dst, array& src, ParentsMap& parents_map) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dst
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
@@ -59,9 +191,7 @@ struct CompilerCache {
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
throw std::runtime_error(
"[compiler] Got different number of inputs to function,"
" this should never happen.");
return false;
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
@@ -205,28 +335,6 @@ void compile_simplify(
}
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
@@ -254,33 +362,32 @@ void compile_simplify(
return pa.is_equivalent(pb);
};
// Pass 0: fuse scalars
// Merge scalars
std::vector<array> new_tape;
for (auto& arr : tape) {
// Check if we can fuse scalars
// Check if we can merge scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
merge(scalar->second, arr, parents_map);
// Don't keep orphaned scalars in the tape
continue;
}
}
new_tape.push_back(std::move(arr));
}
tape = std::move(new_tape);
std::unordered_set<uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
// Multi-pass merge only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {
// Helper to check if we can fuse the parents of the
// Helper to check if we can merge the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto maybe_merge_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
@@ -296,7 +403,7 @@ void compile_simplify(
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
merge(dst, src, parents_map);
mask[j] = true;
}
}
@@ -313,9 +420,9 @@ void compile_simplify(
}
};
bool discard = maybe_fuse_parents(arr);
bool discard = maybe_merge_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_fuse_parents(s);
discard &= maybe_merge_parents(s);
}
// If an array and its siblings have no parents, and none of them are
// outputs, it is safe to remove it from the tape
@@ -327,6 +434,216 @@ void compile_simplify(
}
}
// Extract sub-graphs of the graph that can be compiled
// and replace them with a Compiled Primitive.
void compile_fuse(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Track outputs to replace with new compiled outputs
std::unordered_map<uintptr_t, array> output_map;
for (auto& o : outputs) {
output_map.insert({o.id(), o});
}
// Set of inputs to distinguish constants
std::unordered_set<uintptr_t> input_ids;
for (auto& in : inputs) {
input_ids.insert(in.id());
}
// Go through the tape in reverse order and check for fusable sub-graphs
std::vector<array> new_tape;
std::unordered_set<uintptr_t> global_cache;
for (int i = tape.size() - 1; i >= 0; --i) {
auto& arr = tape[i];
// Already compiled
if (global_cache.find(arr.id()) != global_cache.end()) {
continue;
}
// Two pass recursion:
// First pass:
// - Collect all the primitives which we can fuse with
// - Keeps a cache of fusable primitives which may be added out of
// DAG order. We have to determine if all of a fused primitive's
// outputs are also in the fused section, and this may not be the
// case the first time we visit it.
// Second pass:
// - Collect inputs to the new compiled primitive
// - Add fusable primitives to a tape in the correct order
std::function<void(const array&, int, const Stream&)> recurse;
std::unordered_set<uintptr_t> cache;
recurse = [&](const array& a, int depth, const Stream& s) {
if (cache.find(a.id()) != cache.end()) {
return;
}
// Stop fusing if:
// - Depth limit exceeded
// - Constant input
// - Stream mismatch
// - Non fusable primitive
if (depth >= max_compile_depth || !a.has_primitive() ||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
return;
}
bool all_parents_in = true;
if (depth > 0) {
// Guaranteed to have a parent since nested in the
// recursion.
auto& parents = parents_map.at(a.id());
for (auto& [p, idx] : parents) {
auto in_cache = cache.find(p.id()) != cache.end();
if (!in_cache) {
all_parents_in = false;
break;
}
}
}
// Arrays with a mix of parents outside the compilable section
// are not fusable
if (!all_parents_in) {
return;
}
cache.insert({a.id()});
for (auto& in : a.inputs()) {
recurse(in, depth + 1, s);
}
};
if (arr.has_primitive()) {
Stream s = arr.primitive().stream();
recurse(arr, 0, s);
}
// Not worth fusing a single primitive
if (cache.size() <= 1) {
new_tape.push_back(arr);
continue;
}
// Recurse a second time to build the tape in the right
// order and collect the inputs
std::unordered_set<uintptr_t> input_set;
std::vector<array> inputs;
std::vector<array> fused_tape;
std::unordered_set<uintptr_t> tape_set;
std::function<void(const array&)> recurse_tape;
recurse_tape = [&](const array& a) {
if (cache.find(a.id()) == cache.end()) {
if (input_set.find(a.id()) == input_set.end()) {
input_set.insert(a.id());
inputs.push_back(a);
}
return;
}
if (tape_set.find(a.id()) != tape_set.end()) {
return;
}
tape_set.insert(a.id());
for (auto& in : a.inputs()) {
recurse_tape(in);
}
fused_tape.push_back(a);
};
recurse_tape(arr);
std::vector<array> old_outputs;
// Add to global cache and add any global outputs to outputs
// of new primitive
for (int j = 0; j < fused_tape.size() - 1; ++j) {
auto& f = fused_tape[j];
if (output_map.find(f.id()) != output_map.end()) {
old_outputs.push_back(f);
// Parents are now siblings, update the parent map
auto& pairs = parents_map[f.id()];
pairs.erase(
std::remove_if(
pairs.begin(),
pairs.end(),
[&](auto& p) {
return cache.find(p.first.id()) != cache.end();
}),
pairs.end());
} else {
// Remove inner fused arrays parents from the parents map
// to keep the parents map in a valid state
parents_map.erase(f.id());
}
global_cache.insert({f.id()});
}
old_outputs.push_back(arr);
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
for (auto& o : old_outputs) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
std::unordered_set<uintptr_t> constant_ids;
for (auto& in : inputs) {
// Scalar constant
if (in.size() == 1 && !in.has_primitive() &&
input_ids.find(in.id()) == input_ids.end()) {
constant_ids.insert(in.id());
}
}
auto compiled_outputs = array::make_arrays(
shapes,
types,
std::make_shared<Compiled>(
old_outputs.back().primitive().stream(),
inputs,
old_outputs,
std::move(fused_tape),
std::move(constant_ids)),
inputs);
// One output per primitive
new_tape.push_back(compiled_outputs.back());
// Replace inputs old parents with compiled_outputs
for (int i = 0; i < inputs.size(); ++i) {
auto& pairs = parents_map[inputs[i].id()];
pairs.erase(
std::remove_if(
pairs.begin(),
pairs.end(),
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
pairs.end());
for (auto& o : compiled_outputs) {
pairs.push_back({o, i});
}
}
// - Update outputs parents to point to compiled outputs
// - Update any overall graph outputs to be compiled outputs
for (int o = 0; o < old_outputs.size(); ++o) {
merge(compiled_outputs[o], old_outputs[o], parents_map);
if (auto it = output_map.find(old_outputs[o].id());
it != output_map.end()) {
it->second = compiled_outputs[o];
}
}
}
std::reverse(new_tape.begin(), new_tape.end());
tape = std::move(new_tape);
// Replace output with potentially compiled output
for (auto& o : outputs) {
o = output_map.at(o.id());
}
}
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
@@ -380,10 +697,17 @@ std::vector<array> compile_replace(
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
if (compiler_disabled()) {
if (compile_mode() == CompileMode::disabled) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer();
})) {
return fun(inputs);
}
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);
@@ -402,10 +726,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
if (compile_mode() != CompileMode::no_simplify) {
compile_simplify(
entry.tape, parents_map, entry.outputs, /* passes */ 3);
}
// This is a good point to do more optimizations, e.g. kernel fusion to
// generate new primitives. The tape needs to be updated accordingly
// Kernel fusion to generate Compiled primitives. The tape and
// new outputs must be updated accordingly
if (compile_mode() != CompileMode::no_fuse) {
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
}
}
// At this point we must have a tape, now replace the placeholders
@@ -422,7 +752,7 @@ void compile_erase(size_t fun_id) {
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
if (detail::compiler_disabled()) {
if (detail::compile_mode() == CompileMode::disabled) {
return fun;
}
auto fun_id = detail::getAddress(fun);
@@ -430,11 +760,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
}
void disable_compile() {
detail::compiler_disabled() = true;
detail::compile_mode() = CompileMode::disabled;
}
void enable_compile() {
detail::compiler_disabled() = false;
detail::compile_mode() = CompileMode::enabled;
}
void set_compile_mode(CompileMode mode) {
detail::compile_mode() = mode;
}
} // namespace mlx::core

28
mlx/compile.h Normal file
View File

@@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
/** Set the compiler mode to the given value. */
void set_compile_mode(CompileMode mode);
} // namespace mlx::core

128
mlx/fast.cpp Normal file
View File

@@ -0,0 +1,128 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast.h"
#include "mlx/transforms.h"
namespace mlx::core::fast {
std::vector<array> Custom::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
}
return vjp_outs;
}
std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
}
}
return jvp_outs;
}
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
auto out_axes = std::vector<int>(outputs.size(), 0);
return {outputs, out_axes};
}
array rope(
const array& x,
int dims,
bool traditional,
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */) {
if (x.ndim() != 3) {
std::ostringstream msg;
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
if (traditional && x.shape(-1) != dims) {
throw std::invalid_argument(
"[rope] Does not support partial traditional application.");
}
auto fallback = [dims, traditional, base, scale, offset, s](
const std::vector<array>& inputs) {
auto& x = inputs[0];
auto t = x.dtype();
auto N = x.shape(1) + offset;
// Compute sines and cosines
auto half_dims = dims / 2;
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
auto freqs = negative(arange(0, half_dims, t, s), s);
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
auto theta =
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
auto coss = cos(theta, s);
auto sins = sin(theta, s);
if (traditional) {
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
for (auto& o : outs) {
o = expand_dims(o, 3, s);
}
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
auto x1 = slice(x, {0, 0, 0}, out_s, s);
out_s.back() = dims;
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
std::vector<array> outs;
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
}
return std::vector<array>{concatenate(outs, 2, s)};
}
};
// TODO change to condition for using custom prim
auto stream = to_stream(s);
if (stream.device == Device::gpu && x.shape(-1) == dims) {
return array(
x.shape(),
x.dtype(),
std::make_unique<RoPE>(
stream, fallback, dims, traditional, base, scale, offset),
{x});
}
return fallback({x})[0];
}
bool RoPE::is_equivalent(const Primitive& other) const {
const RoPE& a_other = static_cast<const RoPE&>(other);
return (
dims_ == a_other.dims_ && base_ == a_other.base_ &&
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
offset_ == a_other.offset_);
}
} // namespace mlx::core::fast

82
mlx/fast.h Normal file
View File

@@ -0,0 +1,82 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/ops.h"
#include "mlx/primitives.h"
namespace mlx::core::fast {
// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may to
// override the default behavior
class Custom : public Primitive {
public:
explicit Custom(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback)
: Primitive(stream), fallback_(fallback){};
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
virtual std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
virtual std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
};
array rope(
const array& x,
int dims,
bool traditional,
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */);
class RoPE : public Custom {
public:
RoPE(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int dims,
bool traditional,
float base,
float scale,
int offset)
: Custom(stream, fallback),
dims_(dims),
traditional_(traditional),
base_(base),
scale_(scale),
offset_(offset){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RoPE)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int dims_;
bool traditional_;
float base_;
float scale_;
int offset_;
};
} // namespace mlx::core::fast

View File

@@ -12,27 +12,24 @@
namespace mlx::core {
struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names;
std::string get_name(const array& x) {
auto it = names.find(x.id());
if (it == names.end()) {
// Get the next name in the sequence
// [A, B, ..., Z, AA, AB, ...]
std::vector<char> letters;
auto var_num = names.size() + 1;
while (var_num > 0) {
letters.push_back('A' + (var_num - 1) % 26);
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
return name;
const std::string& NodeNamer::get_name(const array& x) {
auto it = names.find(x.id());
if (it == names.end()) {
// Get the next name in the sequence
// [A, B, ..., Z, AA, AB, ...]
std::vector<char> letters;
auto var_num = names.size() + 1;
while (var_num > 0) {
letters.push_back('A' + (var_num - 1) % 26);
var_num = (var_num - 1) / 26;
}
return it->second;
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
return get_name(x);
}
};
return it->second;
}
void depth_first_traversal(
std::function<void(array)> callback,

View File

@@ -6,6 +6,12 @@
namespace mlx::core {
struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names;
const std::string& get_name(const array& x);
};
void print_graph(std::ostream& os, const std::vector<array>& outputs);
template <typename... Arrays>

View File

@@ -10,6 +10,14 @@
#include "mlx/stream.h"
namespace mlx::core {
using GGUFMetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
using GGUFLoad = std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, GGUFMetaData>>;
using SafetensorsLoad = std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string>>;
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a);
@@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
array load(const std::string& file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors(
SafetensorsLoad load_safetensors(
const std::string& file,
StreamOrDevice s = {});
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>);
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
void save_safetensors(
const std::string& file,
std::unordered_map<std::string, array>);
using MetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
std::unordered_map<std::string, array>,
std::unordered_map<std::string, std::string> metadata = {});
/** Load array map and metadata from .gguf file format */
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s = {});
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
void save_gguf(
std::string file,
std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> meta_data = {});
std::unordered_map<std::string, GGUFMetaData> meta_data = {});
} // namespace mlx::core

View File

@@ -11,14 +11,8 @@ MESSAGE(STATUS "Downloading json")
FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
mlx PUBLIC
mlx PRIVATE
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
$<INSTALL_INTERFACE:include/json>
)
install(
DIRECTORY ${json_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/json
COMPONENT json_source
)
MESSAGE(STATUS "Downloading gguflib")
@@ -28,14 +22,8 @@ FetchContent_Declare(gguflib
)
FetchContent_MakeAvailable(gguflib)
target_include_directories(
mlx PUBLIC
mlx PRIVATE
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/gguflib>
)
install(
DIRECTORY ${gguflib_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/gguflib
COMPONENT gguflib_source
)
add_library(

View File

@@ -82,7 +82,7 @@ void set_mx_value_from_gguf(
gguf_ctx* ctx,
uint32_t type,
gguf_value* val,
MetaData& value) {
GGUFMetaData& value) {
switch (type) {
case GGUF_VALUE_TYPE_UINT8:
value = array(val->uint8, uint8);
@@ -191,12 +191,12 @@ void set_mx_value_from_gguf(
}
}
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
std::unordered_map<std::string, MetaData> metadata;
std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {
std::unordered_map<std::string, GGUFMetaData> metadata;
gguf_key key;
while (gguf_get_key(ctx, &key)) {
std::string key_name = std::string(key.name, key.namelen);
auto& val = metadata.insert({key_name, MetaData{}}).first->second;
auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;
set_mx_value_from_gguf(ctx, key.type, key.val, val);
}
return metadata;
@@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
return array_map;
}
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s) {
GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
gguf_ctx* ctx = gguf_open(file.c_str());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
@@ -280,7 +277,7 @@ void append_kv_array(
void save_gguf(
std::string file,
std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> metadata /* = {} */) {
std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {
// Add .gguf to file name if it is not there
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
file += ".gguf";

View File

@@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) {
}
/** Load array from reader in safetensor format */
std::unordered_map<std::string, array> load_safetensors(
SafetensorsLoad load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s) {
////////////////////////////////////////////////////////
@@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors(
size_t offset = jsonHeaderLength + 8;
// Load the arrays using metadata
std::unordered_map<std::string, array> res;
std::unordered_map<std::string, std::string> metadata_map;
for (const auto& item : metadata.items()) {
if (item.key() == "__metadata__") {
// ignore metadata for now
for (const auto& meta_item : item.value().items()) {
metadata_map.insert({meta_item.key(), meta_item.value()});
}
continue;
}
std::string dtype = item.value().at("dtype");
@@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors(
std::vector<array>{});
res.insert({item.key(), loaded_array});
}
return res;
return {res, metadata_map};
}
std::unordered_map<std::string, array> load_safetensors(
const std::string& file,
StreamOrDevice s) {
SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
return load_safetensors(std::make_shared<io::FileReader>(file), s);
}
/** Save array to out stream in .npy format */
void save_safetensors(
std::shared_ptr<io::Writer> out_stream,
std::unordered_map<std::string, array> a) {
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
@@ -161,9 +163,11 @@ void save_safetensors(
////////////////////////////////////////////////////////
// Check array map
json parent;
parent["__metadata__"] = json::object({
{"format", "mlx"},
});
json _metadata;
for (auto& [key, value] : metadata) {
_metadata[key] = value;
}
parent["__metadata__"] = _metadata;
size_t offset = 0;
for (auto& [key, arr] : a) {
arr.eval();
@@ -204,7 +208,8 @@ void save_safetensors(
void save_safetensors(
const std::string& file_,
std::unordered_map<std::string, array> a) {
std::unordered_map<std::string, array> a,
std::unordered_map<std::string, std::string> metadata /* = {} */) {
// Open and check file
std::string file = file_;
@@ -214,7 +219,7 @@ void save_safetensors(
file += ".safetensors";
// Serialize array
save_safetensors(std::make_shared<io::FileWriter>(file), a);
save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
}
} // namespace mlx::core

View File

@@ -4,7 +4,9 @@
#include "mlx/array.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/compile.h"
#include "mlx/device.h"
#include "mlx/fast.h"
#include "mlx/fft.h"
#include "mlx/io.h"
#include "mlx/linalg.h"

View File

@@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) {
} // namespace
Stream to_stream(StreamOrDevice s) {
if (std::holds_alternative<std::monostate>(s)) {
return default_stream(default_device());
} else if (std::holds_alternative<Device>(s)) {
return default_stream(std::get<Device>(s));
} else {
return std::get<Stream>(s);
}
}
array arange(
double start,
double stop,
@@ -148,6 +138,9 @@ array linspace(
msg << "[linspace] number of samples, " << num << ", must be non-negative.";
throw std::invalid_argument(msg.str());
}
if (num == 1) {
return astype(array({start}), dtype, to_stream(s));
}
array sequence = arange(0, num, float32, to_stream(s));
float step = (stop - start) / (num - 1);
return astype(
@@ -629,6 +622,13 @@ std::vector<array> split(
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
auto ax = axis < 0 ? axis + a.ndim() : axis;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "Invalid axis " << axis << " passed to split"
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto q_and_r = std::ldiv(a.shape(axis), num_splits);
if (q_and_r.rem) {
std::ostringstream msg;
@@ -1316,6 +1316,15 @@ array mean(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
for (int axis : axes) {
if (axis < -ndim || axis >= ndim) {
std::ostringstream msg;
msg << "[mean] axis " << axis << " is out of bounds for array with "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
}
auto nelements = compute_number_of_elements(a, axes);
auto dtype = at_least_float(a.dtype());
return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
@@ -1352,7 +1361,7 @@ array var(
if (ddof != 0) {
auto nelements = compute_number_of_elements(a, axes);
float factor = nelements / (nelements - ddof);
auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
v = multiply(v, array(factor, dtype), s);
}
@@ -1779,10 +1788,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array operator&&(const array& a, const array& b) {
// check if a and b are bool arrays
if (a.dtype() != bool_ || b.dtype() != bool_) {
throw std::invalid_argument("[operator&&] only supported for bool arrays.");
}
return logical_and(a, b);
}
@@ -1797,11 +1802,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs);
}
array operator||(const array& a, const array& b) {
// check if a and b are bool arrays
if (a.dtype() != bool_ || b.dtype() != bool_) {
throw std::invalid_argument(
"[operator||] is only supported for bool arrays.");
}
return logical_or(a, b);
}
@@ -2830,7 +2830,7 @@ array conv2d(
throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet");
}
if (dilation.first != 1 || dilation.second != 1) {
throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet");
throw std::invalid_argument("[conv2d] Cannot handle dilation != 1 yet");
}
// Run checks

View File

@@ -3,18 +3,14 @@
#pragma once
#include <optional>
#include <variant>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
/** Creation operations */
/**

View File

@@ -2,6 +2,8 @@
#pragma once
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/io/load.h"
@@ -451,6 +453,53 @@ class Ceil : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Compiled : public Primitive {
public:
/*
* The inputs, outputs and tape are either tracers or constants.
* - The tape should not contain the inputs, but it should contain the
* outputs.
* - The tape should also have only one array per primitive for multi-output
* primitives.
* - The constant_ids contains ids of arrays in the input list that are safe
* to treat as scalar constants.
*/
explicit Compiled(
Stream stream,
std::vector<array> inputs,
std::vector<array> outputs,
std::vector<array> tape,
std::unordered_set<uintptr_t> constant_ids);
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_GRADS()
void print(std::ostream& os) override;
bool is_equivalent(const Primitive& other) const override;
std::string metal_lib_name() const {
return kernel_lib_;
}
std::string metal_lib_source() const {
return kernel_source_;
}
private:
const std::vector<array> inputs_;
const std::vector<array> outputs_;
const std::vector<array> tape_;
const std::unordered_set<uintptr_t> constant_ids_;
std::string kernel_lib_;
std::string kernel_source_;
void eval(const std::vector<array>& inputs, std::vector<array>& out);
};
class Concatenate : public UnaryPrimitive {
public:
explicit Concatenate(Stream stream, int axis)
@@ -667,9 +716,16 @@ class Equal : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Equal)
DEFINE_DEFAULT_IS_EQUIVALENT()
void print(std::ostream& os) override {
if (equal_nan_) {
os << "NanEqual";
} else {
os << "Equal";
}
}
private:
void eval(const std::vector<array>& inputs, array& out);
bool equal_nan_;
@@ -903,9 +959,22 @@ class Log : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Log)
DEFINE_DEFAULT_IS_EQUIVALENT()
void print(std::ostream& os) override {
switch (base_) {
case e:
os << "Log";
break;
case two:
os << "Log2";
break;
case ten:
os << "Log10";
break;
}
}
private:
Base base_;
void eval(const std::vector<array>& inputs, array& out);
@@ -1552,9 +1621,16 @@ class Sqrt : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Sqrt)
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override {
if (recip_) {
os << "Rsqrt";
} else {
os << "Sqrt";
}
}
private:
void eval(const std::vector<array>& inputs, array& out);
bool recip_;

View File

@@ -153,14 +153,23 @@ array uniform(
array normal(
const std::vector<int>& shape,
Dtype dtype,
const float loc /* = 0.0 */,
const float scale /* = 1.0 */,
const std::optional<array>& key /*= nullopt */,
StreamOrDevice s /* = {} */) {
auto stream = to_stream(s);
auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
auto high = array(1.0f, dtype);
auto samples = uniform(low, high, shape, dtype, key, stream);
return multiply(
array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
samples =
multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
if (scale != 1.0) {
samples = multiply(array(scale, dtype), samples, stream);
}
if (loc != 0.0) {
samples = add(array(loc, dtype), samples, stream);
}
return samples;
}
array randint(

View File

@@ -95,13 +95,30 @@ inline array uniform(
array normal(
const std::vector<int>& shape,
Dtype dtype,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {});
inline array normal(
const std::vector<int>& shape,
const float loc,
const float scale,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, key, s);
return normal(shape, float32, loc, scale, key, s);
}
inline array normal(
const std::vector<int>& shape,
const Dtype dtype,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, dtype, 0.0, 1.0, key, s);
}
inline array normal(
const std::vector<int>& shape,
const std::optional<array>& key = std::nullopt,
StreamOrDevice s = {}) {
return normal(shape, float32, 0.0, 1.0, key, s);
}
/** Generate integer samples uniformly at random */

View File

@@ -6,21 +6,6 @@
namespace mlx::core {
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
void eval(const std::vector<array>& outputs);
template <typename... Arrays>

View File

@@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) {
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}
inline complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));
auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));
if (real != 0 && (real < 0 != b.real() < 0))
real += b.real();
if (imag != 0 && (imag < 0 != b.imag() < 0))
imag += b.imag();
return {real, imag};
}
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
return operator>=(b, a);
}

View File

@@ -7,6 +7,16 @@
namespace mlx::core {
Stream to_stream(StreamOrDevice s) {
if (std::holds_alternative<std::monostate>(s)) {
return default_stream(default_device());
} else if (std::holds_alternative<Device>(s)) {
return default_stream(std::get<Device>(s));
} else {
return std::get<Stream>(s);
}
}
void PrintFormatter::print(std::ostream& os, bool val) {
if (capitalize_bool) {
os << (val ? "True" : "False");

View File

@@ -2,6 +2,8 @@
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "dtype.h"
@@ -9,6 +11,30 @@
namespace mlx::core {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
Stream to_stream(StreamOrDevice s);
struct StreamContext {
public:
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
if (std::holds_alternative<std::monostate>(s)) {
throw std::runtime_error(
"[StreamContext] Invalid argument, please specify a stream or device.");
}
auto _s = to_stream(s);
set_default_device(_s.device);
set_default_stream(_s);
}
~StreamContext() {
set_default_device(_stream.device);
set_default_stream(_stream);
}
private:
Stream _stream;
};
struct PrintFormatter {
inline void print(std::ostream& os, bool val);
inline void print(std::ostream& os, int16_t val);

View File

@@ -60,7 +60,7 @@ def normal(
"""
def initializer(a: mx.array) -> mx.array:
return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean
return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
return initializer
@@ -184,7 +184,7 @@ def glorot_normal(
def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
fan_in, fan_out = _calculate_fan_in_fan_out(a)
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
return mx.random.normal(shape=a.shape, dtype=dtype) * std
return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)
return initializer
@@ -285,7 +285,7 @@ def he_normal(
raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")
std = gain / math.sqrt(fan)
return mx.random.normal(shape=a.shape, dtype=dtype) * std
return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)
return initializer

View File

@@ -58,6 +58,7 @@ from mlx.nn.layers.normalization import (
LayerNorm,
RMSNorm,
)
from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (

View File

@@ -66,6 +66,19 @@ class Module(dict):
"""Boolean indicating if the model is in training mode."""
return self._training
@property
def state(self):
"""The module's state dictionary
The module's state dictionary contains any attribute set on the
module including parameters in :meth:`Module.parameters`
Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
a reference to the module's state. Updates to it will be reflected in
the original module.
"""
return self
def _extra_repr(self):
return ""
@@ -312,7 +325,7 @@ class Module(dict):
elif isinstance(current_value, (dict, list)):
apply(current_value, new_value)
elif isinstance(parameters, list):
for i in range(len(dst)):
for i in range(len(parameters)):
current_value = dst[i]
new_value = parameters[i]
if isinstance(current_value, mx.array):

View File

@@ -21,7 +21,7 @@ class Embedding(Module):
def __init__(self, num_embeddings: int, dims: int):
super().__init__()
scale = math.sqrt(1 / dims)
self.weight = mx.random.normal((num_embeddings, dims)) * scale
self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
def _extra_repr(self):
return f"{self.weight.shape[0]}, {self.weight.shape[1]}"

View File

@@ -0,0 +1,308 @@
# Copyright © 2023-2024 Apple Inc.
import operator
from itertools import accumulate
from typing import Optional, Tuple, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
def _value_or_list(x, n, msg):
if isinstance(x, (list, tuple)):
if len(x) != n:
raise ValueError(msg)
return list(x)
if not isinstance(x, int):
raise ValueError(msg)
return [x] * n
def _sliding_windows(x, window_shape, window_strides):
if x.ndim < 3:
raise ValueError(
f"To extract sliding windows at least 1 spatial dimension "
f"(3 total) is needed but the input only has {x.ndim} dimensions."
)
spatial_dims = x.shape[1:-1]
if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
raise ValueError(
f"To extract sliding windows the window shapes and strides must have "
f"the same number of spatial dimensions as the signal but the signal "
f"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} "
f"and strides have {len(window_strides)}."
)
shape = x.shape
strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]
# Compute the output shape
final_shape = [shape[0]]
final_shape += [
(size - window) // stride + 1
for size, window, stride in zip(spatial_dims, window_shape, window_strides)
]
final_shape += window_shape
final_shape += [shape[-1]]
# Compute the output strides
final_strides = strides[:1]
final_strides += [
og_stride * stride for og_stride, stride in zip(strides[1:-1], window_strides)
]
final_strides += strides[1:-1]
final_strides += strides[-1:] # should always be [1]
return mx.as_strided(x, final_shape, final_strides)
class _Pool(Module):
def __init__(self, pooling_function, kernel_size, stride, padding, padding_value):
super().__init__()
self._pooling_function = pooling_function
self._kernel_size = kernel_size
self._stride = stride
self._padding = padding
self._padding_value = padding_value
self._axes = tuple(range(-len(self._kernel_size) - 1, -1, 1))
def _extra_repr(self):
ks = tuple(self._kernel_size)
st = tuple(self._stride)
pd = tuple(p[0] for p in self._padding)
return f"kernel_size={ks}, stride={st}, padding={pd}"
def __call__(self, x):
if any(p[0] > 0 for p in self._padding):
x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value)
x = _sliding_windows(x, self._kernel_size, self._stride)
return self._pooling_function(x, self._axes)
class _Pool1d(_Pool):
def __init__(
self,
pooling_function,
padding_value,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
class_name = type(self).__name__
msg = "[{}] '{}' must be an integer or a tuple containing 1 integer"
kernel_size = _value_or_list(
kernel_size, 1, msg.format(class_name, "kernel_size")
)
if stride is not None:
stride = _value_or_list(stride, 1, msg.format(class_name, "stride"))
else:
stride = kernel_size
padding = _value_or_list(padding, 1, msg.format(class_name, "padding"))
padding = [(p, p) for p in padding]
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
class _Pool2d(_Pool):
def __init__(
self,
pooling_function,
padding_value,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
class_name = type(self).__name__
msg = "[{}] '{}' must be an integer or a tuple containing 2 integers"
kernel_size = _value_or_list(
kernel_size, 2, msg.format(class_name, "kernel_size")
)
if stride is not None:
stride = _value_or_list(stride, 2, msg.format(class_name, "stride"))
else:
stride = kernel_size
padding = _value_or_list(padding, 2, msg.format(class_name, "padding"))
padding = [(p, p) for p in padding]
super().__init__(pooling_function, kernel_size, stride, padding, padding_value)
class MaxPool1d(_Pool1d):
r"""Applies 1-dimensional max pooling.
Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
:math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
by:
.. math::
\text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1}
\text{input}(N_i, \text{stride} \times t + m, C_j),
where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
\text{kernel_size}}{\text{stride}}\right\rfloor + 1`.
Args:
kernel_size (int or tuple(int)): The size of the pooling window kernel.
stride (int or tuple(int), optional): The stride of the pooling window.
Default: ``kernel_size``.
padding (int or tuple(int), optional): How much negative infinity
padding to apply to the input. The padding amount is applied to
both sides of the spatial axis. Default: ``0``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(4, 16, 5))
>>> pool = nn.MaxPool1d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
class AvgPool1d(_Pool1d):
r"""Applies 1-dimensional average pooling.
Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
:math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
by:
.. math::
\text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1}
\text{input}(N_i, \text{stride} \times t + m, C_j),
where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
\text{kernel_size}}{\text{stride}}\right\rfloor + 1`.
Args:
kernel_size (int or tuple(int)): The size of the pooling window kernel.
stride (int or tuple(int), optional): The stride of the pooling window.
Default: ``kernel_size``.
padding (int or tuple(int), optional): How much zero padding to apply to
the input. The padding amount is applied to both sides of the spatial
axis. Default: ``0``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(4, 16, 5))
>>> pool = nn.AvgPool1d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)
class MaxPool2d(_Pool2d):
r"""Applies 2-dimensional max pooling.
Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
:math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
W_{out}, C)`, given by:
.. math::
\begin{aligned}
\text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times h + m,
\text{stride[1]} \times w + n, C_j),
\end{aligned}
where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for both the
height and width axis;
- a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height axis, the second ``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int)): The size of the pooling window.
stride (int or tuple(int, int), optional): The stride of the pooling
window. Default: ``kernel_size``.
padding (int or tuple(int, int), optional): How much negative infinity
padding to apply to the input. The padding is applied on both sides
of the height and width axis. Default: ``0``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
>>> pool = nn.MaxPool2d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)
class AvgPool2d(_Pool2d):
r"""Applies 2-dimensional average pooling.
Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
:math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
W_{out}, C)`, given by:
.. math::
\begin{aligned}
\text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
& \text{input}(N_i, \text{stride[0]} \times h + m,
\text{stride[1]} \times w + n, C_j),
\end{aligned}
where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
- a single ``int`` -- in which case the same value is used for both the
height and width axis;
- a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
used for the height axis, the second ``int`` for the width axis.
Args:
kernel_size (int or tuple(int, int)): The size of the pooling window.
stride (int or tuple(int, int), optional): The stride of the pooling
window. Default: ``kernel_size``.
padding (int or tuple(int, int), optional): How much zero
padding to apply to the input. The padding is applied on both sides
of the height and width axis. Default: ``0``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn.layers as nn
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
>>> pool = nn.MaxPool2d(kernel_size=2, stride=2)
>>> pool(x)
"""
def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Optional
@@ -20,20 +20,13 @@ class RoPE(Module):
Args:
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool, optional): If set to True choose the traditional
traditional (bool, optional): If set to ``True`` choose the traditional
implementation which is slightly less efficient. Default: ``False``.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``.
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
Attributes:
_cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
_cos_sin_theta_value (tuple): Cached cosine and sine values.
"""
_cos_sin_theta_key = None
_cos_sin_theta_value = None
def __init__(
self,
dims: int,
@@ -50,69 +43,18 @@ class RoPE(Module):
def _extra_repr(self):
return f"{self.dims}, traditional={self.traditional}"
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = mx.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
def __call__(self, x, offset: int = 0):
shape = x.shape
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
x = mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=self.base,
scale=self.scale,
offset=offset,
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return mx.reshape(rx, shape)
@classmethod
def create_cos_sin_theta(
cls,
N: int,
D: int,
offset: int = 0,
base: float = 10000,
scale: float = 1.0,
dtype=mx.float32,
):
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
half_D = D // 2
positions = mx.arange(offset, N, dtype=dtype) * scale
freqs = mx.exp(
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
)
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
return cls._cos_sin_theta_value
return mx.reshape(x, shape)
class SinusoidalPositionalEncoding(Module):

View File

@@ -117,6 +117,7 @@ def cross_entropy(
def binary_cross_entropy(
inputs: mx.array,
targets: mx.array,
weights: mx.array = None,
with_logits: bool = True,
reduction: Reduction = "mean",
) -> mx.array:
@@ -128,6 +129,7 @@ def binary_cross_entropy(
``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
targets (array): The binary target values in {0, 1}.
with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
weights (array, optional): Optional weights for each target. Default: ``None``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
@@ -159,6 +161,15 @@ def binary_cross_entropy(
else:
loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
# Apply weights if provided
if weights is not None:
if weights.shape != loss.shape:
raise ValueError(
f"Weights with shape {weights.shape} is not the same as "
f"output loss with shape {loss.shape}."
)
loss *= weights
return _reduce(loss, reduction)
@@ -533,3 +544,58 @@ def cosine_similarity_loss(
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
return _reduce(loss, reduction)
def margin_ranking_loss(
inputs1: mx.array,
inputs2: mx.array,
targets: mx.array,
margin: float = 0.0,
reduction: Reduction = "none",
) -> mx.array:
r"""
Calculate the margin ranking loss that loss given inputs :math:`x_1`, :math:`x_2` and a label
:math:`y` (containing 1 or -1).
The loss is given by:
.. math::
\text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})
Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2`
represents ``inputs2``.
Args:
inputs1 (array): Scores for the first input.
inputs2 (array): Scores for the second input.
targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher
than samples in ``inputs2``. Values should be 1 or -1.
margin (float, optional): The margin by which the scores should be separated.
Default: ``0.0``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed margin ranking loss.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> targets = mx.array([1, 1, -1])
>>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
>>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
>>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)
>>> loss
array(0.773433, dtype=float32)
"""
if not (inputs1.shape == inputs2.shape == targets.shape):
raise ValueError(
f"The shapes of the arguments do not match. The provided shapes are "
f"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and "
f"targets.shape={targets.shape}."
)
differences = inputs1 - inputs2
loss = mx.maximum(0, -targets * differences + margin)
return _reduce(loss, reduction)

Some files were not shown because too many files have changed in this diff Show More