mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-12 15:24:57 +08:00
Compare commits
50 Commits
Commits in this comparison (author avatars and dates did not survive extraction; SHA1s listed in order):

bf7cd29970, a000d2288c, 165abf0e4c, 818cda16bc, 85143fecdd, 35431a4ac8, ccf1645995, 1a48713d32, 1eb04aa23f, 0c65517e91, 2fdc2462c3, be6e9d6a9f, e54cbb7ba6, 40c108766b, 4cc70290f7, 74caa68d02, 3756381358, d12573daa6, 0dbc4c7547, 06072601ce, 11d2c8f7a1, 7f3f8d8f8d, b96be943dc, b670485185, b57bd0488d, 221f8d3fc2, 5c03efaf29, 7dccd42133, 1b97b2958b, e5e816a5ef, 28eac18571, 5fd11c347d, ef73393a19, ea406d5e33, 146bd69470, 316ff490b3, d40a04f8dc, d75ae52ecd, 31fea3758e, e319383ef9, 5c3ac52dd7, ebfd3618b0, 11a9fd40f0, 4fd2fb84a6, 9852af1a19, 16750f3c51, 95b5fb8245, 83f63f2184, cb6156d35d, 506d43035c
.circleci/config.yml

@@ -1,5 +1,8 @@
 version: 2.1
 
+orbs:
+  apple: ml-explore/pr-approval@0.1.0
+
 parameters:
   nightly_build:
     type: boolean
@@ -7,6 +10,9 @@ parameters:
   weekly_build:
     type: boolean
     default: false
+  test_release:
+    type: boolean
+    default: false
 
 jobs:
   linux_build_and_test:
@@ -57,17 +63,18 @@ jobs:
           command: ./build/tests/tests
 
   mac_build_and_test:
-    machine: true
-    resource_class: ml-explore/m-builder
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.large.gen1
     steps:
       - checkout
       - run:
          name: Install dependencies
          command: |
-           eval "$(conda shell.bash hook)"
-           rm -r $CONDA_PREFIX/envs/runner-env
-           conda create -y -n runner-env python=3.9
-           conda activate runner-env
+           brew install python@3.9
+           python3.9 -m venv env
+           source env/bin/activate
+           pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install pybind11-stubgen
@@ -78,203 +85,158 @@ jobs:
       - run:
          name: Install Python package
          command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
-           CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
-           CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
+           source env/bin/activate
+           CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
       - run:
          name: Generate package stubs
          command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
+           source env/bin/activate
            python setup.py generate_stubs
       - run:
          name: Run Python tests
          command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
-           DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
-           DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
+           source env/bin/activate
+           LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
+           LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
       # TODO: Reenable when extension api becomes stable
       # - run:
       #     name: Build example extension
       #     command: |
-      #       eval "$(conda shell.bash hook)"
-      #       conda activate runner-env
-      #       cd examples/extensions && python -m pip install .
+      #       cd examples/extensions && python3.11 -m pip install .
       - store_test_results:
          path: test-results
       - run:
          name: Build CPP only
          command: |
+           source env/bin/activate
            mkdir -p build && cd build && cmake .. && make -j
       - run:
          name: Run CPP tests
-         command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+         command: |
+           DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+           DEVICE=cpu ./build/tests/tests
 
   build_release:
-    machine: true
-    resource_class: ml-explore/m-builder
     parameters:
       python_version:
         type: string
         default: "3.9"
-      macos_version:
+      xcode_version:
         type: string
-        default: "14"
+        default: "15.2.0"
+      build_env:
+        type: string
+        default: ""
+    macos:
+      xcode: << parameters.xcode_version >>
+    resource_class: macos.m1.large.gen1
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
-           eval "$(conda shell.bash hook)"
-           rm -r $CONDA_PREFIX/envs/runner-env
-           conda create -y -n runner-env python=<< parameters.python_version >>
-           conda activate runner-env
+           brew install python@<< parameters.python_version >>
+           python<< parameters.python_version >> -m venv env
+           source env/bin/activate
+           pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install --upgrade setuptools
            pip install pybind11-stubgen
            pip install numpy
            pip install twine
+           # TODO: Update build system to switch away from setup.py develop
+           pip install build
      - run:
          name: Install Python package
          command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
-           DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-           PYPI_RELEASE=1 \
+           source env/bin/activate
+           DEV_RELEASE=1 \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py develop
+            pip install . -v
      - run:
          name: Generate package stubs
          command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
+           source env/bin/activate
           python setup.py generate_stubs
      - run:
-         name: Publish Python package
+         name: Build Python package
          command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
-           DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-           PYPI_RELEASE=1 \
+           source env/bin/activate
+           << parameters.build_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-           twine upload dist/* --repository mlx
+            python -m build -w
+     - when:
+         condition: << parameters.build_env >>
+         steps:
+           - run:
+               name: Upload package
+               command: |
+                 source env/bin/activate
+                 twine upload dist/*
      - store_artifacts:
          path: dist/
 
-  build_dev_release:
-    machine: true
-    resource_class: ml-explore/m-builder
+  build_linux_test_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
-      macos_version:
+      extra_env:
        type: string
-        default: "14"
+        default: "DEV_RELEASE=1"
+    docker:
+      - image: ubuntu:20.04
    steps:
      - checkout
      - run:
-         name: Install dependencies
+         name: Build wheel
          command: |
-           eval "$(conda shell.bash hook)"
-           rm -r $CONDA_PREFIX/envs/runner-env
-           conda create -y -n runner-env python=<< parameters.python_version >>
-           conda activate runner-env
+           PYTHON=python<< parameters.python_version >>
+           apt-get update
+           apt-get upgrade -y
+           DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
+           apt-get install -y apt-utils
+           apt-get install -y software-properties-common
+           add-apt-repository -y ppa:deadsnakes/ppa
+           apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
+           apt-get install -y libblas-dev liblapack-dev liblapacke-dev
+           apt-get install -y build-essential git
+           $PYTHON -m venv env
+           source env/bin/activate
+           pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install --upgrade setuptools
            pip install pybind11-stubgen
            pip install numpy
            pip install twine
-     - run:
-         name: Install Python package
-         command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
-           DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-           DEV_RELEASE=1 \
+           pip install auditwheel
+           pip install patchelf
+           pip install build
+           << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py develop
-     - run:
-         name: Generate package stubs
-         command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
+           pip install . -v
           python setup.py generate_stubs
-     - run:
-         name: Publish Python package
-         command: |
-           eval "$(conda shell.bash hook)"
-           conda activate runner-env
-           DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-           DEV_RELEASE=1 \
+           << parameters.extra_env >> \
             CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-           twine upload dist/* --repository mlx
+           python -m build --wheel
+           auditwheel show dist/*
+           auditwheel repair dist/* --plat manylinux_2_31_x86_64
      - store_artifacts:
          path: dist/
 
-  build_package:
-    machine: true
-    resource_class: ml-explore/m-builder
-    parameters:
-      python_version:
-        type: string
-        default: "3.9"
-      macos_version:
-        type: string
-        default: "14"
-    steps:
-      - checkout
-      - run:
-          name: Install dependencies
-          command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
-            pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
-            pip install pybind11-stubgen
-            pip install numpy
-            pip install twine
-      - run:
-          name: Install Python package
-          command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py develop
-      - run:
-          name: Generate package stubs
-          command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            python setup.py generate_stubs
-      - run:
-          name: Build package distribution
-          command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-      - store_artifacts:
-          path: dist/
+          path: wheelhouse/
 
 workflows:
   build_and_test:
+    when:
+      and:
+        - matches:
+            pattern: "^(?!pull/)[-\\w]+$"
+            value: << pipeline.git.branch >>
+        - not: << pipeline.parameters.nightly_build >>
+        - not: << pipeline.parameters.weekly_build >>
+        - not: << pipeline.parameters.test_release >>
    jobs:
+      - linux_build_and_test
      - mac_build_and_test
-      - linux_build_and_test
      - build_release:
          filters:
            tags:
@@ -284,20 +246,53 @@ workflows:
         matrix:
           parameters:
             python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-            macos_version: ["13", "14"]
+            xcode_version: ["14.3.1", "15.2.0"]
+            build_env: ["PYPI_RELEASE=1"]
+  prb:
+    when:
+      matches:
+        pattern: "^pull/\\d+(/head)?$"
+        value: << pipeline.git.branch >>
+    jobs:
+      - hold:
+          type: approval
+      - apple/authenticate:
+          context: pr-approval
+      - mac_build_and_test:
+          requires: [ hold ]
+      - linux_build_and_test:
+          requires: [ hold ]
  nightly_build:
-    when: << pipeline.parameters.nightly_build >>
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.nightly_build >>
    jobs:
-      - build_package:
+      - build_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              xcode_version: ["14.3.1", "15.2.0"]
  weekly_build:
-    when: << pipeline.parameters.weekly_build >>
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.weekly_build >>
    jobs:
-      - build_dev_release:
+      - build_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              xcode_version: ["14.3.1", "15.2.0"]
+              build_env: ["DEV_RELEASE=1"]
+  linux_test_release:
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.test_release >>
+    jobs:
+      - build_linux_test_release:
+          matrix:
+            parameters:
+              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              extra_env: ["PYPI_RELEASE=1"]
.pre-commit-config.yaml

@@ -5,7 +5,7 @@ repos:
     - id: clang-format
 # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
 - repo: https://github.com/psf/black-pre-commit-mirror
-  rev: 23.12.1
+  rev: 24.2.0
   hooks:
     - id: black
 - repo: https://github.com/pycqa/isort
ACKNOWLEDGMENTS.md

@@ -10,8 +10,8 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
-- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
+- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
CMakeLists.txt

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.1.0)
+  set(MLX_VERSION 0.3.0)
 endif()
 
 # --------------------- Processor tests -------------------------
@@ -123,8 +123,8 @@ else()
     /usr/include
     /usr/local/include
     $ENV{BLAS_HOME}/include)
-  message(STATUS "Blas lib" ${BLAS_LIBRARIES})
-  message(STATUS "Blas incclude" ${BLAS_INCLUDE_DIRS})
+  message(STATUS "Blas lib " ${BLAS_LIBRARIES})
+  message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
   target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
   target_link_libraries(mlx ${BLAS_LIBRARIES})
   find_package(LAPACK REQUIRED)
@@ -134,7 +134,7 @@ else()
   find_path(LAPACK_INCLUDE_DIRS lapacke.h
     /usr/include
     /usr/local/include)
-  message(STATUS "Lapack lib" ${LAPACK_LIBRARIES})
+  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
   message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
   target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
   target_link_libraries(mlx ${LAPACK_LIBRARIES})
MANIFEST.in

@@ -1,4 +1,4 @@
 include CMakeLists.txt
 recursive-include mlx/ *
 include python/src/*
-python/mlx/py.typed # support type hinting as in PEP-561
+include python/mlx/py.typed # support type hinting as in PEP-561
README.md

@@ -6,8 +6,8 @@
 
 [](https://circleci.com/gh/ml-explore/mlx)
 
-MLX is an array framework for machine learning on Apple silicon, brought to you
-by Apple machine learning research.
+MLX is an array framework for machine learning research on Apple silicon,
+brought to you by Apple machine learning research.
 
 Some key features of MLX include:
@@ -80,10 +80,8 @@ if __name__ == "__main__":
     _filter = make_predicate(args.filter, args.negative_filter)
 
     if args.mlx_dtypes:
-        compare_filtered = (
-            lambda x: compare_mlx_dtypes(
-                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
-            )
+        compare_filtered = lambda x: (
+            compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
             if _filter(x)
             else None
         )
benchmarks/python/gather_bench.py (new file, 53 lines)

# Copyright © 2023-2024 Apple Inc.

import argparse
from time import time

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_gather_mlx(x_shape, idx_shape):
    def gather(x, idx):
        mx.eval(x[idx])

    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    runtime = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_gather_torch(x_shape, idx_shape, device):
    def gather(x, idx, device):
        _ = x[idx]
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    idx_shapes = [(1_000_000,), (100_000,), ()]
    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]

    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
benchmarks/python/rope_bench.py (new file, 35 lines)

# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(4096)

    # vec
    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
benchmarks/python/scatter_bench.py (new file, 56 lines)

# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
    def scatter(dst, x, idx):
        dst[idx] = x
        mx.eval(dst)

    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
    def gather(dst, x, idx, device):
        dst[idx] = x
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
benchmarks/python/time_utils.py

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import time
 
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
 
     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 100
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters
docs/.gitignore (vendored)

@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/
@@ -1,19 +0,0 @@
-{{ fullname | escape | underline}}
-
-.. currentmodule:: {{ module }}
-
-.. autoclass:: {{ objname }}
-
-{#{% block methods %}
-
-   {% if methods %}
-   .. rubric:: {{ _('Methods') }}
-
-   .. autosummary::
-   {% for item in methods %}
-      {%- if item not in inherited_members and item != '__init__' %}
-      ~{{ name }}.{{ item }}
-      {%- endif %}
-   {%- endfor %}
-   {% endif %}
-   {% endblock %}#}
@@ -26,6 +26,7 @@ extensions = [
 
 python_use_unqualified_type_names = True
 autosummary_generate = True
+autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
 
 intersphinx_mapping = {
     "https://docs.python.org/3": None,
@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
 You would really like the part of your applications that does this operation
 on the CPU to be very fast - so you decide that you want it to rely on the
 ``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
-our assumptions on to you, let's also assume that you want to learn how add
+our assumptions on to you, let's also assume that you want to learn how to add
 your own implementation for the gradients of your new operation while going
 over the ins-and-outs of the MLX framework.
 
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
 Binding to Python
 ^^^^^^^^^^^^^^^^^^
 
-We use PyBind11_ to build a Python API for the C++ library. Since bindings
-for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
-are already provided, adding our :meth:`axpby` becomes very simple!
+We use PyBind11_ to build a Python API for the C++ library. Since bindings for
+components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
+already provided, adding our :meth:`axpby` is simple!
 
 .. code-block:: C++
 
@@ -927,18 +927,18 @@ Results:
 
 We see some modest improvements right away!
 
-This operation is now good to be used to build other operations,
-in :class:`mlx.nn.Module` calls, and also as a part of graph
-transformations like :meth:`grad`!
+This operation is now good to be used to build other operations, in
+:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
+:meth:`grad`!
 
 Scripts
 -------
 
 .. admonition:: Download the code
 
-   The full example code is available in `mlx-examples <code>`_.
+   The full example code is available in `mlx <code>`_.
 
-.. code: `TODO_LINK/extensions`_
+.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
 
 .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
 .. _Metal: https://developer.apple.com/documentation/metal?language=objc
@@ -41,6 +41,7 @@ are the CPU and GPU.
    usage/indexing
    usage/saving_and_loading
    usage/function_transforms
+   usage/compile
    usage/numpy
    usage/using_streams
@@ -9,9 +9,10 @@ Devices and Streams
    :toctree: _autosummary
 
    Device
+   Stream
    default_device
    set_default_device
-   Stream
    default_stream
    new_stream
    set_default_stream
    stream
@@ -10,6 +10,8 @@ Layers
    :template: nn-module-template.rst
 
    ALiBi
+   AvgPool1d
+   AvgPool2d
    BatchNorm
    Conv1d
    Conv2d
@@ -22,6 +24,8 @@ Layers
    InstanceNorm
    LayerNorm
    Linear
+   MaxPool1d
+   MaxPool2d
    Mish
    MultiHeadAttention
    PReLU
@@ -18,6 +18,7 @@ Loss Functions
    kl_div_loss
    l1_loss
+   log_cosh_loss
    margin_ranking_loss
    mse_loss
    nll_loss
    smooth_l1_loss
@@ -11,6 +11,7 @@ Module
    :toctree: _autosummary
 
    Module.training
+   Module.state
 
 .. rubric:: Methods
 
@@ -29,20 +29,8 @@ model's parameters and the **optimizer state**.
     # Compute the new parameters but also the optimizer state.
     mx.eval(model.parameters(), optimizer.state)
 
-.. currentmodule:: mlx.optimizers
+.. toctree::
 
-.. autosummary::
-   :toctree: _autosummary
-   :template: optimizers-template.rst
-
-   OptimizerState
-   Optimizer
-   SGD
-   RMSprop
-   Adagrad
-   Adafactor
-   AdaDelta
-   Adam
-   AdamW
-   Adamax
-   Lion
+   optimizers/optimizer
+   optimizers/common_optimizers
+   optimizers/schedulers
docs/src/python/optimizers/common_optimizers.rst (new file, 20 lines)

.. _common_optimizers:

Common Optimizers
=================

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion
docs/src/python/optimizers/optimizer.rst (new file, 23 lines)

Optimizer
=========

.. currentmodule:: mlx.optimizers

.. autoclass:: Optimizer

   .. rubric:: Attributes

   .. autosummary::
      :toctree: _autosummary

      Optimizer.state

   .. rubric:: Methods

   .. autosummary::
      :toctree: _autosummary

      Optimizer.apply_gradients
      Optimizer.init
      Optimizer.update
docs/src/python/optimizers/schedulers.rst (new file, 13 lines)

.. _schedulers:

Schedulers
==========

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary

   step_decay
   exponential_decay
   cosine_decay
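A scheduler here is a function of the optimizer step that produces a learning
rate. As a minimal usage sketch (assuming a schedule can be passed as the
``learning_rate`` of an optimizer; the constants are illustrative, not from the
upstream docs):

.. code-block:: python

   import mlx.optimizers as optim

   # Decay the learning rate from 1e-1 following a cosine curve over 1000 steps
   lr_schedule = optim.cosine_decay(1e-1, 1000)
   optimizer = optim.SGD(learning_rate=lr_schedule)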
@@ -9,6 +9,9 @@ Transforms
    :toctree: _autosummary
 
    eval
+   compile
+   disable_compile
+   enable_compile
    grad
    value_and_grad
    jvp
docs/src/usage/compile.rst (new file, 430 lines)

.. _compile:

Compilation
===========

.. currentmodule:: mlx.core

MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.

Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.

Basics of Compile
-----------------

Let's start with a simple example:

.. code-block:: python

   def fun(x, y):
       return mx.exp(-x) + y

   x = mx.array(1.0)
   y = mx.array(2.0)

   # Regular call, no compilation
   # Prints: array(2.36788, dtype=float32)
   print(fun(x, y))

   # Compile the function
   compiled_fun = mx.compile(fun)

   # Prints: array(2.36788, dtype=float32)
   print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.

.. code-block:: python

   def fun(x, y):
       return mx.exp(-x) + y

   x = mx.array(1.0)
   y = mx.array(2.0)

   compiled_fun = mx.compile(fun)

   # Compiled here
   compiled_fun(x, y)

   # Not compiled again
   compiled_fun(x, y)

   # Not compiled again
   mx.compile(fun)(x, y)

There are some important cases to be aware of that can cause a function to
be recompiled:

* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function

A sketch of these recompile triggers follows this list. In certain cases only
some of the compilation stack will be rerun (for example when changing the
shapes) and in other cases the full compilation stack will be rerun (for
example when changing the types). In general you should avoid compiling
functions too frequently.
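As a concrete illustrative sketch of the cases above (the shapes and dtypes
here are assumptions for illustration, not from the upstream docs; integer
literals produce ``int32`` arrays by default):

.. code-block:: python

   compiled_fun = mx.compile(lambda x: mx.exp(-x))

   compiled_fun(mx.array([1.0, 2.0]))       # traced and compiled here
   compiled_fun(mx.array([3.0, 4.0]))       # same shape and dtype: cached
   compiled_fun(mx.array([1.0, 2.0, 3.0]))  # new shape: part of the stack reruns
   compiled_fun(mx.array([1, 2]))           # new dtype (int32): full rerun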
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:

.. code-block:: python

   a = mx.array(1.0)
   # Don't do this, compiles lambda at each iteration
   for _ in range(5):
       mx.compile(lambda x: mx.exp(mx.abs(x)))(a)

Example Speedup
---------------

The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:

.. code-block:: python

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.

Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:

.. code-block:: python

   import time

   def timeit(fun, x):
       # warm up
       for _ in range(10):
           mx.eval(fun(x))

       tic = time.perf_counter()
       for _ in range(100):
           mx.eval(fun(x))
       toc = time.perf_counter()
       tpi = 1e3 * (toc - tic) / 100
       print(f"Time per iteration {tpi:.3f} (ms)")

Now make an array, and benchmark both functions:

.. code-block:: python

   x = mx.random.uniform(shape=(32, 1000, 4096))
   timeit(nn.gelu, x)
   timeit(mx.compile(nn.gelu), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

.. note::

   As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
   functions can still be helpful, but won't typically result in as large a
   speedup as compiling operations that run on the GPU.

Debugging
---------

When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.

.. code-block:: python

   @mx.compile
   def fun(x):
       z = -x
       print(z)  # Crash
       return mx.exp(z)

   fun(mx.array(5.0))

For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:

.. code-block:: python

   @mx.compile
   def fun(x):
       z = -x
       print(z)  # Okay
       return mx.exp(z)

   mx.disable_compile()
   fun(mx.array(5.0))

Pure Functions
--------------

Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:

.. code-block:: python

   state = []

   @mx.compile
   def fun(x, y):
       z = x + y
       state.append(z)
       return mx.exp(z)

   fun(mx.array(1.0), mx.array(2.0))
   # Crash!
   print(state)

After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.

You have two options to deal with this. The first option is to simply return
``state`` as an output:

.. code-block:: python

   state = []

   @mx.compile
   def fun(x, y):
       z = x + y
       state.append(z)
       return mx.exp(z), state

   _, state = fun(mx.array(1.0), mx.array(2.0))
   # Prints [array(3, dtype=float32)]
   print(state)

In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:

.. code-block:: python

   from functools import partial

   state = []

   # Tell compile to capture state as an output
   @partial(mx.compile, outputs=state)
   def fun(x, y):
       z = x + y
       state.append(z)
       return mx.exp(z), state

   fun(mx.array(1.0), mx.array(2.0))
   # Prints [array(3, dtype=float32)]
   print(state)

This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.

Compiled functions will also treat any inputs not in the parameter list as
constants. For example:

.. code-block:: python

   state = [mx.array(1.0)]

   @mx.compile
   def fun(x):
       return x + state[0]

   # Prints array(2, dtype=float32)
   print(fun(mx.array(1.0)))

   # Update state
   state[0] = mx.array(5.0)

   # Still prints array(2, dtype=float32)
   print(fun(mx.array(1.0)))

In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:

.. code-block:: python

   from functools import partial

   state = [mx.array(1.0)]

   # Tell compile to capture state as an input
   @partial(mx.compile, inputs=state)
   def fun(x):
       return x + state[0]

   # Prints array(2, dtype=float32)
   print(fun(mx.array(1.0)))

   # Update state
   state[0] = mx.array(5.0)

   # Prints array(6, dtype=float32)
   print(fun(mx.array(1.0)))

Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.

To start, here is the simple example without any compilation:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   import mlx.optimizers as optim

   # 4 examples with 10 features each
   x = mx.random.uniform(shape=(4, 10))

   # 0, 1 targets
   y = mx.array([0, 1, 0, 1])

   # Simple linear model
   model = nn.Linear(10, 1)

   # SGD with momentum
   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

   def loss_fn(model, x, y):
       logits = model(x).squeeze()
       return nn.losses.binary_cross_entropy(logits, y)

   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

   # Perform 10 steps of gradient descent
   for it in range(10):
       loss, grads = loss_and_grad_fn(model, x, y)
       optimizer.update(model, grads)
       mx.eval(model.parameters(), optimizer.state)

To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   import mlx.optimizers as optim
   from functools import partial

   # 4 examples with 10 features each
   x = mx.random.uniform(shape=(4, 10))

   # 0, 1 targets
   y = mx.array([0, 1, 0, 1])

   # Simple linear model
   model = nn.Linear(10, 1)

   # SGD with momentum
   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

   def loss_fn(model, x, y):
       logits = model(x).squeeze()
       return nn.losses.binary_cross_entropy(logits, y)

   # The state that will be captured as input and output
   state = [model.state, optimizer.state]

   @partial(mx.compile, inputs=state, outputs=state)
   def step(x, y):
       loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
       loss, grads = loss_and_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   # Perform 10 steps of gradient descent
   for it in range(10):
       loss = step(x, y)
       # Evaluate the model and optimizer state
       mx.eval(state)
       print(loss)

.. note::

   If you are using a module which performs random sampling such as
   :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in
   the ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
   optimizer.state, mx.random.state]``.

.. note::

   For more examples of compiling full training graphs checkout the `MLX
   Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.

Transformations with Compile
----------------------------

In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.

Compiling transformed functions works just as expected:

.. code-block:: python

   grad_fn = mx.grad(mx.exp)

   compiled_grad_fn = mx.compile(grad_fn)

   # Prints: array(2.71828, dtype=float32)
   print(grad_fn(mx.array(1.0)))

   # Also prints: array(2.71828, dtype=float32)
   print(compiled_grad_fn(mx.array(1.0)))

.. note::

   In order to compile as much as possible, a transformation of a compiled
   function will not by default be compiled. To compile the transformed
   function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:

.. code-block:: python

   @mx.compile
   def inner(x):
       return mx.exp(-mx.abs(x))

   def outer(x):
       inner(inner(x))

   # Compiling the outer function is good to do as it will likely
   # be faster even though the inner functions are compiled
   fun = mx.compile(outer)
docs/src/usage/function_transforms.rst

@@ -5,9 +5,12 @@ Function Transforms
 
 .. currentmodule:: mlx.core
 
-MLX uses composable function transformations for automatic differentiation and
-vectorization. The key idea behind composable function transformations is that
-every transformation returns a function which can be further transformed.
+MLX uses composable function transformations for automatic differentiation,
+vectorization, and compute graph optimizations. To see the complete list of
+function transformations check-out the :ref:`API documentation <transforms>`.
+
+The key idea behind composable function transformations is that every
+transformation returns a function which can be further transformed.
 
 Here is a simple example:
 
@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
 getting higher order derivatives.
 
 Any of the MLX function transformations can be composed in any order to any
-depth. To see the complete list of function transformations check-out the
-:ref:`API documentation <transforms>`. See the following sections for more
-information on :ref:`automatic differentiation <auto diff>` and
-:ref:`automatic vectorization <vmap>`.
+depth. See the following sections for more information on :ref:`automatic
+differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
+For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
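A short illustrative sketch of the higher-order case mentioned above: since
every transformation returns a function, :func:`grad` can be applied directly
to its own output. The second derivative of ``sin`` is ``-sin``:

.. code-block:: python

   import mlx.core as mx

   d2 = mx.grad(mx.grad(mx.sin))

   # Prints approximately -sin(1.0) = -0.841471
   print(d2(mx.array(1.0)))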
 
 Automatic Differentiation
 -------------------------
examples/extensions/CMakeLists.txt

@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.24)
+cmake_minimum_required(VERSION 3.27)
 
 project(mlx_sample_extensions LANGUAGES CXX)
 
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
 
 if(BUILD_SHARED_LIBS)
   target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
-endif()
\ No newline at end of file
+endif()
mlx/CMakeLists.txt

@@ -3,9 +3,10 @@ target_sources(
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
mlx/array.cpp

@@ -82,6 +82,13 @@ array::array(std::initializer_list<float> data)
   init(data.begin());
 }
 
+array::array(std::initializer_list<int> data, Dtype dtype)
+    : array_desc_(std::make_shared<ArrayDesc>(
+          std::vector<int>{static_cast<int>(data.size())},
+          dtype)) {
+  init(data.begin());
+}
+
 /* Build an array from a shared buffer */
 array::array(
     allocator::Buffer data,
@@ -180,7 +187,7 @@ array::ArrayDesc::ArrayDesc(
     primitive(std::move(primitive)),
     inputs(inputs) {
   std::tie(size, strides) = cum_prod(this->shape);
-  for (auto& in : inputs) {
+  for (auto& in : this->inputs) {
     is_tracer |= in.is_tracer();
     depth = std::max(in.graph_depth(), depth);
   }
@@ -197,7 +204,7 @@ array::ArrayDesc::ArrayDesc(
     primitive(std::move(primitive)),
     inputs(std::move(inputs)) {
   std::tie(size, strides) = cum_prod(this->shape);
-  for (auto& in : inputs) {
+  for (auto& in : this->inputs) {
     is_tracer |= in.is_tracer();
     depth = std::max(in.graph_depth(), depth);
   }
mlx/array.h

@@ -41,6 +41,9 @@ class array {
   /* Special case so empty lists default to float32. */
   array(std::initializer_list<float> data);
 
+  /* Special case so array({}, type) is an empty array. */
+  array(std::initializer_list<int> data, Dtype dtype);
+
   template <typename T>
   array(
       std::initializer_list<T> data,
@@ -121,6 +124,9 @@ class array {
   template <typename T>
   T item();
 
+  template <typename T>
+  T item() const;
+
   struct ArrayIterator {
     using iterator_category = std::random_access_iterator_tag;
     using difference_type = size_t;
@@ -454,6 +460,18 @@ T array::item() {
   return *data<T>();
 }
 
+template <typename T>
+T array::item() const {
+  if (size() != 1) {
+    throw std::invalid_argument("item can only be called on arrays of size 1.");
+  }
+  if (!is_evaled()) {
+    throw std::invalid_argument(
+        "item() const can only be called on evaled arrays");
+  }
+  return *data<T>();
+}
+
 template <typename It>
 void array::init(It src) {
   set_data(allocator::malloc(size() * size_of(dtype())));
@@ -46,6 +46,9 @@ inline void matmul_cblas_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (M == 0 || N == 0) {
+    return;
+  }
   if (K == 0) {
     std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
     return;
@@ -94,6 +97,9 @@ inline void matmul_bnns_general(
   size_t N = b.shape(-1);
   size_t K = a.shape(-1);
 
+  if (M == 0 || N == 0) {
+    return;
+  }
   if (K == 0) {
     std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
     return;
@@ -33,10 +33,12 @@ DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
+DEFAULT_MULTI(Compiled)
 DEFAULT(Concatenate)
 DEFAULT(Copy)
 DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)
+DEFAULT_MULTI(DivMod)
 DEFAULT(Equal)
 DEFAULT(Erf)
 DEFAULT(ErfInv)
@@ -57,8 +59,10 @@ DEFAULT(Minimum)
 DEFAULT(NotEqual)
 DEFAULT(Pad)
 DEFAULT(Partition)
+DEFAULT_MULTI(QRF)
 DEFAULT(RandomBits)
 DEFAULT(Reshape)
+DEFAULT(Remainder)
 DEFAULT(Round)
 DEFAULT(Scatter)
 DEFAULT(Sigmoid)
@@ -68,8 +72,6 @@ DEFAULT_MULTI(Split)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
 DEFAULT(Transpose)
-DEFAULT_MULTI(DivMod)
-DEFAULT_MULTI(QRF)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
@@ -291,45 +293,6 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
-// TODO: Avoid code duplication with the common backend.
-struct RemainderFn {
-  template <typename T>
-  std::enable_if_t<!std::is_integral_v<T>, T> operator()(
-      T numerator,
-      T denominator) {
-    return std::fmod(numerator, denominator);
-  }
-
-  template <typename T>
-  std::enable_if_t<std::is_integral_v<T>, T> operator()(
-      T numerator,
-      T denominator) {
-    return numerator % denominator;
-  }
-};
-
-void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (a.dtype() == float32) {
-    binary(
-        a,
-        b,
-        out,
-        RemainderFn{},
-        UseDefaultBinaryOp(),
-        UseDefaultBinaryOp(),
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          int num_el = n;
-          vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
-        });
-  } else {
-    binary(a, b, out, RemainderFn{});
-  }
-}
-
 void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Make sure that the last dimension is contiguous
   auto check_input = [](array x) {
-    if (x.strides()[x.ndim() - 1] == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
@@ -10,6 +11,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
|
@@ -140,16 +140,34 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
 
 struct RemainderFn {
   template <typename T>
-  std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+  std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
       T numerator,
       T denominator) {
-    return std::fmod(numerator, denominator);
+    return numerator % denominator;
   }
 
   template <typename T>
-  std::enable_if_t<std::is_integral_v<T>, T> operator()(
+  std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
       T numerator,
       T denominator) {
-    return numerator % denominator;
+    auto r = numerator % denominator;
+    if (r != 0 && (r < 0 != denominator < 0))
+      r += denominator;
+    return r;
+  }
+
+  template <typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(
+      T numerator,
+      T denominator) {
+    auto r = std::fmod(numerator, denominator);
+    if (r != 0 && (r < 0 != denominator < 0)) {
+      r += denominator;
+    }
+    return r;
+  }
+
+  complex64_t operator()(complex64_t numerator, complex64_t denominator) {
+    return numerator % denominator;
+  }
 };
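The net effect of the rewritten RemainderFn above is that the remainder of
signed and floating-point operands now takes the sign of the denominator,
matching Python's % operator rather than C's truncated semantics. A minimal
illustrative sketch (the printed values assume the Python-style semantics
implemented above):

    import mlx.core as mx

    # C-style fmod/% would give -1 here; Python-style remainder gives 2
    print(mx.remainder(mx.array(-7), mx.array(3)))  # array(2, dtype=int32)
    print(-7 % 3)                                   # 2, the Python reference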
mlx/backend/common/compiled.cpp (new file, 59 lines)

// Copyright © 2023-2024 Apple Inc.

#include <queue>

#include "mlx/primitives.h"

namespace mlx::core {

// Build the real tape
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
    const std::vector<array>& trace_tape,
    const std::vector<array>& trace_inputs,
    const std::vector<array>& trace_outputs,
    const std::vector<array>& inputs) {
  std::unordered_map<uintptr_t, array> trace_to_real;
  for (int i = 0; i < inputs.size(); ++i) {
    trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
  }
  std::queue<array> tape;
  for (auto& a : trace_tape) {
    // Find real inputs
    std::vector<array> real_inputs;
    for (auto& in : a.inputs()) {
      real_inputs.push_back(trace_to_real.at(in.id()));
    }
    tape.push(
        array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
    trace_to_real.insert({a.id(), tape.back()});
  }

  std::vector<array> outputs;
  for (auto& o : trace_outputs) {
    outputs.push_back(trace_to_real.at(o.id()));
  }
  return {tape, outputs};
}

void Compiled::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Make a real tape from the tracers
  auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);

  // Run the tape
  while (!tape.empty()) {
    auto a = std::move(tape.front());
    tape.pop();
    auto outputs = a.outputs();
    a.primitive().eval_cpu(a.inputs(), outputs);
    a.detach();
  }

  // Copy results into outputs
  for (int o = 0; o < real_outputs.size(); ++o) {
    outputs[o].copy_shared_buffer(real_outputs[o]);
  }
}

} // namespace mlx::core
@@ -3,7 +3,7 @@
 #include <cassert>
 
 #ifdef ACCELERATE_NEW_LAPACK
-#include <vecLib/cblas_new.h>
+#include <Accelerate/Accelerate.h>
 #else
 #include <cblas.h>
 #endif
@@ -1,7 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #ifdef ACCELERATE_NEW_LAPACK
-#include <vecLib/cblas_new.h>
+#include <Accelerate/Accelerate.h>
 #else
 #include <cblas.h>
 #endif
@@ -41,7 +41,9 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT_MULTI(Compiled)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
@@ -78,6 +80,7 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
@@ -100,8 +103,6 @@ DEFAULT(Subtract)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)

namespace {
@@ -131,7 +132,9 @@ inline void matmul_common_general(
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  if (M == 0 || N == 0) {
    return;
  }
  if (K == 0) {
    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
    return;
@@ -5,7 +5,7 @@
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/lapack.h>
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
14
mlx/backend/common/rope.cpp
Normal file
@@ -0,0 +1,14 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
    if (x.strides().back() == 1) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
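The widened check above accepts an input whose last axis is unit-stride and whose second-to-last stride is either 0 (a broadcast row) or exactly the size of the last axis (packed rows). A standalone restatement of the predicate with example shapes; the function name is illustrative:

#include <cstddef>
#include <vector>

// True when softmax can read rows directly without copying:
// last axis unit-stride, and rows either broadcast (stride 0) or packed.
bool rows_are_packed(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t n = shape.size();
  bool no_copy = strides[n - 1] == 1;
  if (n > 1) {
    size_t s = strides[n - 2];
    no_copy &= (s == 0 || s == static_cast<size_t>(shape[n - 1]));
  }
  return no_copy;
}

// e.g. shape {4, 8}, strides {8, 1}  -> true  (packed rows)
//      shape {4, 8}, strides {0, 1}  -> true  (broadcast rows)
//      shape {4, 8}, strides {16, 1} -> false (padded rows force a copy)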
@@ -1,7 +1,28 @@
add_custom_command(
  OUTPUT compiled_preamble.cpp
  COMMAND /bin/bash
    ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
    ${CMAKE_C_COMPILER}
    ${CMAKE_SOURCE_DIR}
  DEPENDS make_compiled_preamble.sh
    kernels/compiled_preamble.h
    kernels/unary.h
    kernels/binary.h
)

add_custom_target(
  compiled_preamble
  DEPENDS compiled_preamble.cpp
)

add_dependencies(mlx compiled_preamble)

target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
@@ -11,10 +32,12 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

if (NOT MLX_METAL_PATH)
484
mlx/backend/metal/compiled.cpp
Normal file
@@ -0,0 +1,484 @@
// Copyright © 2023-2024 Apple Inc.

#include <sstream>

#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

inline bool is_static_cast(const Primitive& p) {
  return (
      typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
      typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}

inline auto get_type_string(Dtype d) {
  switch (d) {
    case float32:
      return "float";
    case float16:
      return "half";
    case bfloat16:
      return "bfloat16_t";
    case bool_:
      return "bool";
    case int8:
      return "int8_t";
    case int16:
      return "int16_t";
    case int32:
      return "int32_t";
    case int64:
      return "int64_t";
    case uint8:
      return "uint8_t";
    case uint16:
      return "uint16_t";
    case uint32:
      return "uint32_t";
    case uint64:
      return "uint64_t";
    default: {
      std::ostringstream msg;
      msg << "Unsupported compilation type " << d;
      throw std::runtime_error(msg.str());
    }
  }
}

template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
  auto old_precision = os.precision();
  os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
     << x.item<T>() << std::setprecision(old_precision);
}

template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
  os << x.item<T>();
}

void print_constant(std::ostream& os, const array& x) {
  switch (x.dtype()) {
    case float32:
      return print_float_constant<float>(os, x);
    case float16:
      return print_float_constant<float16_t>(os, x);
    case bfloat16:
      return print_float_constant<bfloat16_t>(os, x);
    case int8:
      return print_int_constant<int8_t>(os, x);
    case int16:
      return print_int_constant<int16_t>(os, x);
    case int32:
      return print_int_constant<int32_t>(os, x);
    case int64:
      return print_int_constant<int64_t>(os, x);
    case uint8:
      return print_int_constant<uint8_t>(os, x);
    case uint16:
      return print_int_constant<uint16_t>(os, x);
    case uint32:
      return print_int_constant<uint32_t>(os, x);
    case uint64:
      return print_int_constant<uint64_t>(os, x);
    case bool_:
      os << std::boolalpha << x.item<bool>();
      return;
    default:
      throw std::runtime_error("Unsupported constant type");
  }
}

inline std::string build_lib_name(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids) {
  std::ostringstream os;
  std::ostringstream constant_hasher;

  // The primitives describing the tape. For unary and binary primitives this
  // must be enough to describe the full computation.
  for (auto& a : tape) {
    a.primitive().print(os);
  }
  os << "_";

  for (auto& x : inputs) {
    if (constant_ids.find(x.id()) != constant_ids.end()) {
      os << "C";
      print_constant(constant_hasher, x);
    } else {
      os << ((x.size() == 1) ? "S" : "V");
    }
  }
  os << "_";
  for (auto& x : inputs) {
    if (constant_ids.find(x.id()) != constant_ids.end()) {
      continue;
    }
    os << kindof(x.dtype()) << x.itemsize();
  }
  os << "_" << std::hash<std::string>{}(constant_hasher.str());

  return os.str();
}

inline void build_kernel(
    std::ostream& os,
    const std::string& kernel_name,
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids,
    bool contiguous,
    int ndim,
    bool dynamic_dims) {
  // All outputs should have the exact same shape and will be row contiguous
  auto output_shape = outputs[0].shape();
  auto output_strides = outputs[0].strides();

  // Constants are scalars that are captured by value and cannot change
  auto is_constant = [&constant_ids](const array& x) {
    return constant_ids.find(x.id()) != constant_ids.end();
  };

  // For scalars we shouldn't do the indexing things, just read at 0
  auto is_scalar = [](const array& x) { return x.size() == 1; };

  NodeNamer namer;
  bool add_indices = false;
  int cnt = 0;

  // Start the kernel
  os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
     << "[[kernel]] void " << kernel_name << "(" << std::endl;

  // Add the input arguments
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

    // Skip constants from the input list
    if (is_constant(x)) {
      continue;
    }

    // Scalars and contiguous need no strides
    if (is_scalar(x) || contiguous) {
      os << " device const " << get_type_string(x.dtype()) << "* " << xname
         << " [[buffer(" << cnt++ << ")]]," << std::endl;
    } else {
      add_indices = true;
      os << " device const " << get_type_string(x.dtype()) << "* " << xname
         << " [[buffer(" << cnt++ << ")]]," << std::endl
         << " constant const size_t* " << xname << "_strides [[buffer("
         << cnt++ << ")]]," << std::endl;
    }
  }

  // Add the output arguments
  for (auto& x : outputs) {
    os << " device " << get_type_string(x.dtype()) << "* "
       << namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
  }
  // Add output strides and shape to extract the indices.
  if (!contiguous) {
    os << " constant const size_t* output_strides [[buffer(" << cnt++
       << ")]]," << std::endl
       << " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
       << std::endl;
  }
  if (dynamic_dims) {
    os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
       << std::endl;
  }

  // The thread index in the whole grid
  os << " uint3 pos [[thread_position_in_grid]]," << std::endl
     << " uint3 grid [[threads_per_grid]]) {" << std::endl
     << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
     << std::endl;

  // Extract the indices per axis to individual uints if we have arrays that
  // are broadcasted or transposed
  if (add_indices) {
    if (!dynamic_dims) {
      if (ndim == 1) {
        os << " uint index_0 = pos.x;" << std::endl;
      } else if (ndim == 2) {
        os << " uint index_0 = pos.y;" << std::endl
           << " uint index_1 = pos.x;" << std::endl;
      } else if (ndim == 3) {
        os << " uint index_0 = pos.z;" << std::endl
           << " uint index_1 = pos.y;" << std::endl
           << " uint index_2 = pos.x;" << std::endl;
      } else {
        for (int i = 0; i < ndim - 2; i++) {
          os << " uint index_" << i << " = (index / uint(output_strides[" << i
             << "])) % output_shape[" << i << "];" << std::endl;
        }
        os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
           << " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
      }
    }
  }

  // Read the inputs in tmps
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

    if (is_constant(x)) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
      print_constant(os, x);
      os << ";" << std::endl;
    } else if (is_scalar(x)) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
         << xname << "[0];" << std::endl;
    } else if (contiguous) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
         << xname << "[index];" << std::endl;
    } else if (!dynamic_dims) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
         << xname << "[";
      os << "index_0 * " << xname << "_strides[0]";
      for (int i = 1; i < ndim; i++) {
        os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
      }
      os << "];" << std::endl;
    } else {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
         << xname << "[elem_to_loc(index, output_shape, " << xname
         << "_strides, ndim)];" << std::endl;
    }
  }

  // Actually write the computation
  for (auto& x : tape) {
    os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
       << " = ";
    if (is_static_cast(x.primitive())) {
      os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
         << namer.get_name(x.inputs()[0]) << ");" << std::endl;
    } else {
      x.primitive().print(os);
      os << "()(";
      for (int i = 0; i < x.inputs().size() - 1; i++) {
        os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
      }
      os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
    }
  }

  // Write the outputs from tmps
  for (auto& x : outputs) {
    os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
       << ";" << std::endl;
  }

  // Finish the kernel
  os << "}" << std::endl;

  if (cnt > 31) {
    std::ostringstream msg;
    msg << "[compile] Too many inputs/outputs fused in the Metal Compiled "
        << "primitive which exhausted the available argument buffers for "
        << "the kernel. Please file an issue with the function that results "
        << "in this error. The name of the kernel is '" << kernel_name << "'";
    throw std::runtime_error(msg.str());
  }
}

void Compiled::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  // Make the name for the kernel library
  if (kernel_lib_.empty()) {
    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
  }

  // Get the kernel if someone else built it already
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto lib = d.get_library(kernel_lib_);

  // If not, we have to build it ourselves
  if (lib == nullptr) {
    std::ostringstream kernel;
    kernel << metal::get_kernel_preamble() << std::endl;
    build_kernel(
        kernel,
        kernel_lib_ + "_contiguous",
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        /* contiguous = */ true,
        /* ndim = */ 0,
        /* dynamic_dims = */ false);
    for (int i = 1; i < 8; i++) {
      build_kernel(
          kernel,
          kernel_lib_ + "_strided_" + std::to_string(i),
          inputs_,
          outputs_,
          tape_,
          constant_ids_,
          /* contiguous = */ false,
          /* ndim = */ i,
          /* dynamic_dims = */ false);
    }
    build_kernel(
        kernel,
        kernel_lib_ + "_strided_dynamic",
        inputs_,
        outputs_,
        tape_,
        constant_ids_,
        /* contiguous = */ false,
        /* ndim = */ 0,
        /* dynamic_dims = */ true);

    kernel_source_ = kernel.str();
    lib = d.get_library(kernel_lib_, kernel_source_);
  }

  // Allocate space for the outputs
  for (auto& out : outputs) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  // Figure out which kernel we are using
  auto& output_shape = outputs[0].shape();
  bool contiguous = true;
  for (auto& x : inputs) {
    if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
        x.size() > 1) {
      contiguous = false;
      break;
    }
  }

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  std::vector<std::vector<size_t>> initial_strides;
  initial_strides.push_back(outputs[0].strides());
  std::vector<int> shape;
  std::vector<std::vector<size_t>> strides;
  if (!contiguous) {
    for (int i = 0; i < inputs.size(); i++) {
      // Skip constants.
      if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
        continue;
      }
      auto& x = inputs[i];

      // Skip scalar inputs.
      if (x.size() <= 1) {
        continue;
      }

      // Broadcast the inputs to the output shape.
      std::vector<size_t> xstrides;
      int j = 0;
      for (; j < output_shape.size() - x.ndim(); j++) {
        if (output_shape[j] == 1) {
          xstrides.push_back(outputs[0].strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      }
      for (int i = 0; i < x.ndim(); i++, j++) {
        if (x.shape(i) == 1) {
          if (output_shape[j] == 1) {
            xstrides.push_back(outputs[0].strides()[j]);
          } else {
            xstrides.push_back(0);
          }
        } else {
          xstrides.push_back(x.strides()[i]);
        }
      }
      initial_strides.push_back(std::move(xstrides));
    }
    std::tie(shape, strides) =
        collapse_contiguous_dims(output_shape, initial_strides);
  }

  // Get the kernel from the lib
  int ndim = shape.size();
  bool dynamic = ndim >= 8;
  auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  if (!contiguous) {
    if (dynamic) {
      kernel_name += "dynamic";
    } else {
      kernel_name += std::to_string(shape.size());
    }
  }
  auto kernel = d.get_kernel(kernel_name, lib);
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  // Put the inputs in
  int cnt = 0;
  int stride_idx = 1; // idx 0 is the output strides
  for (int i = 0; i < inputs.size(); i++) {
    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
      continue;
    }
    auto& x = inputs[i];
    set_array_buffer(compute_encoder, x, cnt++);
    if (!contiguous && x.size() > 1) {
      compute_encoder->setBytes(
          strides[stride_idx].data(),
          strides[stride_idx].size() * sizeof(size_t),
          cnt++);
      stride_idx++;
    }
  }

  // Put the outputs in
  for (auto& x : outputs) {
    set_array_buffer(compute_encoder, x, cnt++);
  }

  // Put the output shape and strides in
  if (!contiguous) {
    compute_encoder->setBytes(
        strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
    compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
  }

  // Put the number of dims in if it is dynamic
  if (dynamic) {
    compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
  }

  // Launch the kernel
  if (contiguous) {
    size_t nthreads = outputs[0].size();
    MTL::Size grid_dims(nthreads, 1, 1);
    MTL::Size group_dims(
        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = outputs[0].size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}

} // namespace mlx::core
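eval_gpu compiles one contiguous variant, strided variants for ndim 1 through 7, and a dynamic-ndim fallback, then dispatches on the collapsed dimensionality. A condensed restatement of that selection rule; this is not MLX API, and the library name in the comment is made up for illustration:

#include <string>

// Mirror of the dispatch rule above: pick which compiled variant to run
// from the collapsed dimensionality (suffixes follow build_kernel's names).
std::string pick_variant(const std::string& lib, bool contiguous, int ndim) {
  if (contiguous) {
    return lib + "_contiguous";
  }
  if (ndim >= 8) {
    return lib + "_strided_dynamic"; // ndim passed at runtime instead
  }
  return lib + "_strided_" + std::to_string(ndim);
}

// pick_variant("Add_VV_f4f4_0", false, 3) == "Add_VV_f4f4_0_strided_3"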
9
mlx/backend/metal/compiled_preamble.h
Normal file
@@ -0,0 +1,9 @@
// Copyright © 2023-24 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* get_kernel_preamble();

}
@@ -26,7 +26,8 @@ static constexpr const char* default_mtllib_path = METAL_PATH;

auto load_device() {
  auto devices = MTL::CopyAllDevices();
  auto device = static_cast<MTL::Device*>(devices->object(0));
  auto device = static_cast<MTL::Device*>(devices->object(0))
      ?: MTL::CreateSystemDefaultDevice();
  if (!device) {
    throw std::runtime_error("Failed to load device");
  }
@@ -214,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
  return eit->second;
}

MTL::ArgumentEncoder* Device::argument_encoder(
    const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
  // NB array here is already autoreleased but the returned argument
  // encoder is owned by the caller and must be released/autoreleased
  NS::Array* arg_desc_arr = NS::Array::array(
      reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
  return device_->newArgumentEncoder(arg_desc_arr);
}

void Device::register_library(
    const std::string& lib_name,
    const std::string& lib_path) {
@@ -413,6 +405,11 @@ MTL::ComputePipelineState* Device::get_kernel_(
  return kernel;
}

MTL::Library* Device::get_library(const std::string& name) {
  auto it = library_map_.find(name);
  return (it != library_map_.end()) ? it->second : nullptr;
}

MTL::Library* Device::get_library(
    const std::string& name,
    const std::string& source,
@@ -62,6 +62,8 @@ class Device {
      const std::function<std::string(const std::string&)>& lib_path_func =
          get_colocated_mtllib_path);

  MTL::Library* get_library(const std::string& name);

  MTL::Library* get_library(
      const std::string& name,
      const std::string& source_string,
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@@ -39,114 +39,75 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  size_t ndim = src.ndim();

  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
  if (idx_ndim <= 1) {
    kname << "_" << idx_ndim;
  }

  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  size_t slice_size = 1;
  for (auto s : slice_sizes_) {
    slice_size *= s;
  }

  size_t ndim = src.ndim();
  size_t nthreads = out.size();
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
  }
  // Launch 2D grid of threads: indices x slice
  size_t dim0 = out.size() / slice_size;
  size_t dim1 = slice_size;
  auto group_dims = get_block_dims(dim0, dim1, 1);
  MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);

  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  // Collect all idx shapes and strides into one place
  std::vector<int> idx_shapes;
  std::vector<size_t> idx_strides;

  compute_encoder->setComputePipelineState(kernel);

  // Make the argument buffer to store the indices for the
  // `Indices` struct in kernels/indexing.metal
  std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
  arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[0]->setIndex(0);
  arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[0]->setArrayLength(nidx);

  // Shapes
  arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[1]->setIndex(nidx + 1);

  // Strides
  arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[2]->setIndex(nidx + 2);

  // Indices ndim
  arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
  arg_descs[3]->setIndex(nidx + 3);

  // Get the argument encoder
  auto arg_enc = d.argument_encoder(arg_descs);

  // Allocate and fill buffers for shapes and strides
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy(
    idx_shapes.insert(
        idx_shapes.end(),
        inputs[i + 1].shape().begin(),
        inputs[i + 1].shape().end(),
        static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
    std::copy(
        inputs[i + 1].shape().end());

    idx_strides.insert(
        idx_strides.end(),
        inputs[i + 1].strides().begin(),
        inputs[i + 1].strides().end(),
        static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
        inputs[i + 1].strides().end());
  }

  // Allocate the argument buffer
  auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());

  // Register data with the encoder
  arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
  for (int i = 0; i < nidx; ++i) {
    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
  }
  if (idx_ndim > 0) {
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
        MTL::ResourceUsageRead);
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
        MTL::ResourceUsageRead);
  }
  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;

  // Set all the buffers
  set_array_buffer(compute_encoder, src, 0);
  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
  set_array_buffer(compute_encoder, out, 2);
  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
  compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
  compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
  set_array_buffer(compute_encoder, out, 1);

  // Set source info
  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
  compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);

  // Set index info
  //
  // We don't need to check for empty idx_shapes because gather has an
  // idx_ndim == 0 specialization
  compute_encoder->setBytes(
      idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
  compute_encoder->setBytes(
      idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);

  // Set index buffers
  for (int i = 1; i < nidx + 1; ++i) {
    set_array_buffer(compute_encoder, inputs[i], 20 + i);
  }

  // Launch grid
  compute_encoder->dispatchThreads(grid_dims, group_dims);

  // Cleanup temporaries
  arg_enc->release();
  d.get_command_buffer(s.index)->addCompletedHandler(
      [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
        allocator::free(arg_buf);
        allocator::free(idx_shapes_buf);
        allocator::free(idx_strides_buf);
      });
}

void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -211,82 +172,35 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    thread_group_size = nthreads;
  }

  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

  compute_encoder->setComputePipelineState(kernel);

  // Make the argument buffer to store the indices for the
  // `Indices` struct in kernels/indexing.metal
  std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
  arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[0]->setIndex(0);
  arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[0]->setArrayLength(nidx);

  // Shapes
  arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[1]->setIndex(nidx + 1);

  // Strides
  arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
  arg_descs[2]->setIndex(nidx + 2);

  // Indices ndim
  arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
  arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
  arg_descs[3]->setIndex(nidx + 3);

  // Get the argument encoder
  auto arg_enc = d.argument_encoder(arg_descs);

  // Allocate and fill buffers for shapes and strides
  // Collect all idx shapes and strides into one place
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
  auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
  std::vector<int> idx_shapes;
  std::vector<size_t> idx_strides;

  for (int i = 0; i < nidx; ++i) {
    std::copy(
    idx_shapes.insert(
        idx_shapes.end(),
        inputs[i + 1].shape().begin(),
        inputs[i + 1].shape().end(),
        static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
    std::copy(
        inputs[i + 1].shape().end());

    idx_strides.insert(
        idx_strides.end(),
        inputs[i + 1].strides().begin(),
        inputs[i + 1].strides().end(),
        static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
        inputs[i + 1].strides().end());
  }

  // Allocate the argument buffer
  auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
  // Set all the buffers
  set_array_buffer(compute_encoder, upd, 1);
  set_array_buffer(compute_encoder, out, 2);

  // Register data with the encoder
  arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
  for (int i = 0; i < nidx; ++i) {
    set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
  }
  if (idx_ndim > 0) {
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
        MTL::ResourceUsageRead);
    arg_enc->setBuffer(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
    compute_encoder->useResource(
        static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
        MTL::ResourceUsageRead);
  }
  *static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;

  compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
  // Set update info
  size_t upd_ndim = upd.ndim();
  size_t upd_size = 1;
  for (int i = idx_ndim; i < upd.ndim(); ++i) {
    upd_size *= upd.shape(i);
  }
  set_array_buffer(compute_encoder, upd, 1);
  set_array_buffer(compute_encoder, out, 2);
  if (upd_ndim == 0) {
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
@@ -301,6 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

  // Set output info
  size_t out_ndim = out.ndim();
  if (out_ndim == 0) {
    // Need placeholders so Metal doesn't complain
@@ -316,16 +231,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

  compute_encoder->dispatchThreads(grid_dims, group_dims);
  // Set index info
  if (idx_ndim == 0) {
    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
    // error in the metal API.
    idx_shapes.push_back(0);
    idx_strides.push_back(0);
  }
  compute_encoder->setBytes(
      idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
  compute_encoder->setBytes(
      idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
  compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

  // Cleanup temporaries
  arg_enc->release();
  d.get_command_buffer(s.index)->addCompletedHandler(
      [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
        allocator::free(arg_buf);
        allocator::free(idx_shapes_buf);
        allocator::free(idx_strides_buf);
      });
  // Set index buffers
  for (int i = 1; i < nidx + 1; ++i) {
    set_array_buffer(compute_encoder, inputs[i], 20 + i);
  }

  // Launch grid
  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core
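The rewritten gather launch swaps the flat 1-D grid for a 2-D grid: one axis enumerates index combinations, the other walks elements within each gathered slice. A small worked check of that arithmetic, with illustrative sizes:

#include <cassert>
#include <cstddef>

int main() {
  // Say gather produces out.size() == 24 elements and each gathered
  // slice holds 6 elements (the product of slice_sizes_).
  size_t out_size = 24;
  size_t slice_size = 6;

  // 2-D launch: dim0 counts index combinations, dim1 walks the slice.
  size_t dim0 = out_size / slice_size; // 4 gathered slices
  size_t dim1 = slice_size;            // 6 elements per slice
  assert(dim0 * dim1 == out_size);     // every output element is covered

  // In the kernel, thread (x, y) writes out[y + grid_dim.y * x].
  size_t x = 3, y = 5, grid_y = dim1;
  assert(y + grid_y * x == 23);        // last thread hits the last element
}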
@@ -6,6 +6,7 @@ set(
  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@@ -22,11 +23,13 @@ set(
  "quantized"
  "random"
  "reduce"
  "rope"
  "scan"
  "softmax"
  "sort"
  "unary"
  "indexing"
  "gather"
  "scatter"
)

function(build_kernel_base TARGET SRCFILE DEPS)
231
mlx/backend/metal/kernels/binary.h
Normal file
@@ -0,0 +1,231 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

struct Add {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};

struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
};

struct Remainder {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x % y;
  }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};

struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y;
  }
};

struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y || (metal::isnan(x) && metal::isnan(y));
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x == y ||
        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
         metal::isnan(y.imag)) ||
        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
        (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
  }
};

struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x <= y;
  }
};

struct LogAddExp {
  template <typename T>
  T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits<T>::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits<T>::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    return (minval == -inf || maxval == inf)
        ? maxval
        : (maxval + log1p(metal::exp(minval - maxval)));
  };
};

struct Maximum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};

struct Minimum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};

struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x != y;
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x.real != y.real || x.imag != y.imag;
  }
};

struct Power {
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    auto x_theta = metal::atan(x.imag / x.real);
    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};

struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  };
};

struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    return x || y;
  };
};
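The integral Power overload above uses exponentiation by squaring, so an exponent with n bits costs O(n) multiplies. The same loop as plain C++ with a worked check; this assumes a non-negative exponent, as the kernel effectively does:

#include <cassert>

// Exponentiation by squaring: consume the exponent bit by bit,
// squaring the base at each step.
long ipow(long base, unsigned exp) {
  long res = 1;
  while (exp) {
    if (exp & 1) { // this bit contributes base^(2^k)
      res *= base;
    }
    exp >>= 1;
    base *= base;
  }
  return res;
}

int main() {
  assert(ipow(3, 5) == 243);   // exp bits 101: res picks up 3 and 81
  assert(ipow(2, 10) == 1024);
}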
@@ -1,176 +1,6 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"

struct Add {
  template <typename T> T operator()(T x, T y) { return x + y; }
};

struct Divide {
  template <typename T> T operator()(T x, T y) { return x / y; }
};

struct Remainder {
  template <typename T> T operator()(T x, T y) { return x % y; }
  template <> float operator()(float x, float y) { return fmod(x, y); }
  template <> half operator()(half x, half y) { return fmod(x, y); }
  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};

struct Equal {
  template <typename T> bool operator()(T x, T y) { return x == y; }
};

struct NaNEqual {
  template <typename T> bool operator()(T x, T y) {
    return x == y || (metal::isnan(x) && metal::isnan(y));
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x == y ||
      (metal::isnan(x.real) && metal::isnan(y.real)
       && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
      (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
      (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
  }
};

struct Greater {
  template <typename T> bool operator()(T x, T y) { return x > y; }
};

struct GreaterEqual {
  template <typename T> bool operator()(T x, T y) { return x >= y; }
};

struct Less {
  template <typename T> bool operator()(T x, T y) { return x < y; }
};

struct LessEqual {
  template <typename T> bool operator()(T x, T y) { return x <= y; }
};

struct LogAddExp {
  template <typename T>
  T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits<T>::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits<T>::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    return (minval == -inf || maxval == inf) ? maxval :
      (maxval + log1p(metal::exp(minval - maxval)));
  };
};

struct Maximum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};

struct Minimum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};

struct Multiply {
  template <typename T> T operator()(T x, T y) { return x * y; }
};

struct NotEqual {
  template <typename T> bool operator()(T x, T y) { return x != y; }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x.real != y.real || x.imag != y.imag;
  }
};

struct Power {
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    auto x_theta = metal::atan(x.imag / x.real);
    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};

struct Subtract {
  template <typename T> T operator()(T x, T y) { return x - y; }
};

struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) { return x && y; };
};

struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) { return x || y; };
};
#include "mlx/backend/metal/kernels/binary.h"

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
@@ -14,10 +14,29 @@ struct FloorDivide {
};

struct Remainder {
  template <typename T> T operator()(T x, T y) { return x % y; }
  template <> float operator()(float x, float y) { return fmod(x, y); }
  template <> half operator()(half x, half y) { return fmod(x, y); }
  template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
    return x % y;
  }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <> complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};

template <typename T, typename U, typename Op1, typename Op2>
4
mlx/backend/metal/kernels/compiled_preamble.h
Normal file
@@ -0,0 +1,4 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/unary.h"
@@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
  if (real != 0 && (real < 0 != b.real < 0)) {
    real += b.real;
  }
  if (imag != 0 && (imag < 0 != b.imag < 0)) {
    imag += b.imag;
  }
  return {real, imag};
}
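operator% applies the remainder componentwise, each component following the same sign-of-denominator convention as the scalar kernels. A quick host-side check of the formula using plain doubles (not the Metal complex type):

#include <cassert>
#include <cstdint>

// Componentwise floor-style remainder, mirroring operator% above.
double component_mod(double a, double b) {
  double r = a - b * static_cast<int64_t>(a / b);
  if (r != 0 && ((r < 0) != (b < 0))) {
    r += b;
  }
  return r;
}

int main() {
  // (-7 + 7i) % (3 + 3i) applied per component:
  assert(component_mod(-7.0, 3.0) == 2.0);
  assert(component_mod(7.0, 3.0) == 1.0);
}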
187
mlx/backend/metal/kernels/gather.metal
Normal file
@@ -0,0 +1,187 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T *src [[buffer(0)]],
    device T *out [[buffer(1)]],
    const constant int *src_shape [[buffer(2)]],
    const constant size_t *src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int *slice_sizes [[buffer(5)]],
    const constant int *axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(
      ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
    const device T *src [[buffer(0)]], \
    device T *out [[buffer(1)]], \
    const constant int *src_shape [[buffer(2)]], \
    const constant size_t *src_strides [[buffer(3)]], \
    const constant size_t& src_ndim [[buffer(4)]], \
    const constant int *slice_sizes [[buffer(5)]], \
    const constant int *axes [[buffer(6)]], \
    const constant int *idx_shapes [[buffer(7)]], \
    const constant size_t *idx_strides [[buffer(8)]], \
    const constant int& idx_ndim [[buffer(9)]], \
    IDX_ARG(IdxT) \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]) { \
  \
  Indices<IdxT, NIDX> idxs{ \
      {{IDX_ARR()}}, \
      idx_shapes, \
      idx_strides, \
      idx_ndim}; \
  \
  return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
      src, \
      out, \
      src_shape, \
      src_strides, \
      src_ndim, \
      slice_sizes, \
      axes, \
      idxs, \
      index, \
      grid_dim); \
}

#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)

make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)

/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
    const device src_t *src [[buffer(0)]], \
    device src_t *out [[buffer(1)]], \
    const constant int *src_shape [[buffer(2)]], \
    const constant size_t *src_strides [[buffer(3)]], \
    const constant size_t& src_ndim [[buffer(4)]], \
    const constant int *slice_sizes [[buffer(5)]], \
    const constant int *axes [[buffer(6)]], \
    const constant int *idx_shapes [[buffer(7)]], \
    const constant size_t *idx_strides [[buffer(8)]], \
    const constant int& idx_ndim [[buffer(9)]], \
    IDX_ARG(idx_t) \
    uint2 index [[thread_position_in_grid]], \
    uint2 grid_dim [[threads_per_grid]]);

#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
  instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)

#define instantiate_gather4(name, src_t, idx_t, nidx) \
  instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
  instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
  instantiate_gather5(name, src_t, idx_t, nidx, 2, )

// Special case for NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)

#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10)

#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
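These macros fan out to one specialization per source type, index type, index count, and index rank, and the host-side kname built in Gather::eval_gpu must reproduce the emitted host_name exactly. A C++ restatement of that naming scheme, mirroring the kname construction shown earlier in this diff (the function itself is illustrative, not MLX API):

#include <string>

// Reconstruct the host_name that instantiate_gather6 emits:
// "gather" + src type + idx type + "_" + nidx, plus "_0"/"_1" when the
// index arrays are at most one-dimensional (rank >= 2 shares one kernel).
std::string gather_kernel_name(
    const std::string& src_t, // e.g. "float32"
    const std::string& idx_t, // e.g. "int32"; empty when nidx == 0
    int nidx,
    int idx_ndim) {
  std::string name = "gather" + src_t + idx_t + "_" + std::to_string(nidx);
  if (idx_ndim <= 1) {
    name += "_" + std::to_string(idx_ndim);
  }
  return name;
}

// gather_kernel_name("float32", "int32", 2, 1) == "gatherfloat32int32_2_1"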
54
mlx/backend/metal/kernels/indexing.h
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT, int NIDX>
|
||||
struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant size_t* strides;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
if (is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
}
|
||||
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],

#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)

#define IDX_ARR_N(n) idx##n,

#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
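To make the recursion concrete, `IDX_ARG_2(uint32_t)` unrolls to two buffer-qualified kernel parameters (`idx21`, `idx22`) and `IDX_ARR_2()` to the matching initializer entries. A runnable host-side check of the token pasting, with the Metal-only `device` and `[[buffer(n)]]` pieces dropped:

#include <cstdio>

// Host mock of the expansion; only the pasting and 21..30 numbering
// scheme are checked here.
#define IDX_ARG_N(idx_t, n) const idx_t* idx##n,
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)

#define STR2(x) #x
#define STR(x) STR2(x)

int main() {
  // Prints: const unsigned* idx21, const unsigned* idx22,
  std::puts(STR(IDX_ARG_2(unsigned)));
}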
@@ -1,254 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <metal_atomic>
#include <metal_texture>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename IdxT, int NIDX>
struct Indices {
  const array<device IdxT*, NIDX> buffers [[id(0)]];
  device int* shapes [[id(NIDX + 1)]];
  device size_t* strides [[id(NIDX + 2)]];
  const int ndim [[id(NIDX + 3)]];
};

template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

template <>
inline size_t offset_neg_idx(bool idx, size_t) {
  return idx;
}

template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
  return idx;
}

template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
    const device T *src [[buffer(0)]],
    const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
    device T *out [[buffer(2)]],
    const device int *src_shape [[buffer(3)]],
    const device size_t *src_strides [[buffer(4)]],
    const device size_t& src_ndim [[buffer(5)]],
    const device int *slice_sizes [[buffer(6)]],
    const device size_t& slice_size [[buffer(7)]],
    const device int *axes [[buffer(8)]],
    uint gid [[thread_position_in_grid]]) {

  auto ind_idx = gid / slice_size;
  auto ind_offset = gid % slice_size;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(
      ind_offset, slice_sizes, src_strides, src_ndim);
  out[gid] = src[src_idx + src_offset];
}
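Each thread decomposes its flat id into which index entry it serves (ind_idx) and where it sits inside the gathered slice (ind_offset). A simplified host-side C++ sketch of the same arithmetic for a single-axis row gather (dense row-major source; the values are illustrative only):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Gather rows {2, 0} from a 4x3 row-major source: slice_size = 3, so
  // thread gid fills out[gid] with ind_idx = gid / 3 picking the index
  // entry and ind_offset = gid % 3 the position inside the slice.
  std::vector<float> src = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  std::vector<int> idx = {2, 0};
  size_t slice_size = 3, src_stride0 = 3;

  std::vector<float> out(idx.size() * slice_size);
  for (size_t gid = 0; gid < out.size(); ++gid) {
    size_t ind_idx = gid / slice_size;
    size_t ind_offset = gid % slice_size;
    size_t src_idx = idx[ind_idx] * src_stride0; // offset along axis 0
    out[gid] = src[src_idx + ind_offset];
  }
  assert(out[0] == 6 && out[3] == 0); // rows 2 and 0, in that order
}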
#define instantiate_gather4(name, src_type, ind_type, nindex) \
  template [[host_name("gather" name "_" #nindex)]] \
  [[kernel]] void gather( \
      const device src_type *src [[buffer(0)]], \
      const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
      device src_type *out [[buffer(2)]], \
      const device int *src_shape [[buffer(3)]], \
      const device size_t *src_strides [[buffer(4)]], \
      const device size_t& src_ndim [[buffer(5)]], \
      const device int *slice_sizes [[buffer(6)]], \
      const device size_t& slice_size [[buffer(7)]], \
      const device int* axes [[buffer(8)]], \
      uint gid [[thread_position_in_grid]]);

// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)

#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10)

#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
    const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
    const device T *updates [[buffer(1)]],
    device mlx_atomic<T> *out [[buffer(2)]],
    const device int *upd_shape [[buffer(3)]],
    const device size_t *upd_strides [[buffer(4)]],
    const device size_t& upd_ndim [[buffer(5)]],
    const device size_t& upd_size [[buffer(6)]],
    const device int *out_shape [[buffer(7)]],
    const device size_t *out_strides [[buffer(8)]],
    const device size_t& out_ndim [[buffer(9)]],
    const device int* axes [[buffer(10)]],
    uint gid [[thread_position_in_grid]]) {

  Op op;
  auto ind_idx = gid / upd_size;
  auto ind_offset = gid % upd_size;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  auto out_offset = elem_to_loc(
      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
  auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
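The `Op` parameter decides how colliding updates combine (None, Sum, Prod, Max, Min in the instantiations below); since two threads may target the same output element, every write goes through an atomic read-modify-write. A stripped-down host-side C++ sketch of the idea, with std::atomic standing in for mlx_atomic (names illustrative):

#include <atomic>
#include <cassert>

// Sum-style scatter op: repeated indices accumulate instead of racing.
struct Sum {
  void atomic_update(std::atomic<int>* out, int val, size_t offset) {
    out[offset].fetch_add(val);
  }
};

int main() {
  std::atomic<int> out[3] = {{0}, {0}, {0}};
  Sum op;
  int updates[] = {5, 7};
  size_t idx[] = {1, 1}; // both updates hit the same element
  for (int i = 0; i < 2; ++i) {
    op.atomic_update(out, updates[i], idx[i]);
  }
  assert(out[1].load() == 12);
}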
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
  template [[host_name("scatter" name "_" #nindex)]] \
  [[kernel]] void scatter<type, ind_type, op_type, nindex>( \
      const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
      const device type *updates [[buffer(1)]], \
      device mlx_atomic<type> *out [[buffer(2)]], \
      const device int *upd_shape [[buffer(3)]], \
      const device size_t *upd_strides [[buffer(4)]], \
      const device size_t& upd_ndim [[buffer(5)]], \
      const device size_t& upd_size [[buffer(6)]], \
      const device int *out_shape [[buffer(7)]], \
      const device size_t *out_strides [[buffer(8)]], \
      const device size_t& out_ndim [[buffer(9)]], \
      const device int* axes [[buffer(10)]], \
      uint gid [[thread_position_in_grid]]);

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)

#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
  instantiate_scatter4(name, type, ind_type, op_type, 3) \
  instantiate_scatter4(name, type, ind_type, op_type, 4) \
  instantiate_scatter4(name, type, ind_type, op_type, 5) \
  instantiate_scatter4(name, type, ind_type, op_type, 6) \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10)

#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>)

#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
  instantiate_scatter2(#name "uint16", type, uint16_t) \
  instantiate_scatter2(#name "uint32", type, uint32_t) \
  instantiate_scatter2(#name "uint64", type, uint64_t) \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t)

// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)

instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
@@ -15,6 +15,14 @@ using namespace metal;

MLX_MTL_CONST int SIMD_SIZE = 32;

template <typename T> struct AccT {
  typedef T acc_t;
};

template <> struct AccT<bfloat16_t> {
  typedef float acc_t;
};
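The `AccT` trait selects the accumulation type: every input type accumulates in itself except bfloat16, which is promoted to float for the running dot product so that its short (8-bit) mantissa does not swallow small partial sums. A host-side C++ sketch of the same pattern, with a forward declaration standing in for the Metal bfloat16 type:

#include <type_traits>

struct bfloat16_t; // stand-in for the Metal type

// Default: accumulate in the input type itself.
template <typename T> struct AccT {
  typedef T acc_t;
};

// bfloat16 promotes to float for the accumulator.
template <> struct AccT<bfloat16_t> {
  typedef float acc_t;
};

static_assert(std::is_same_v<AccT<float>::acc_t, float>);
static_assert(std::is_same_v<AccT<bfloat16_t>::acc_t, float>);

int main() {}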
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
@@ -31,21 +39,23 @@ template <typename T, const int BM, const int BN, const int group_size, const in

  static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");

  (void)lid;

  constexpr int bitmask = (1 << bits) - 1;
  constexpr int el_per_thread = 32 / bits;
  constexpr int colgroup = BN * el_per_thread;
  constexpr int groups_per_block = colgroup / group_size;
  constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;

  threadgroup T scales_block[BM * groups_per_block];
  threadgroup T biases_block[BM * groups_per_block];
  threadgroup T x_block[colgroup];
  typedef typename AccT<T>::acc_t U;
  threadgroup U scales_block[BM * groups_per_block];
  threadgroup U biases_block[BM * groups_per_block];
  threadgroup U x_block[colgroup];

  thread uint32_t w_local;
  thread T result = 0;
  thread T scale = 1;
  thread T bias = 0;
  thread T x_thread[el_per_thread];
  thread U result = 0;
  thread U scale = 1;
  thread U bias = 0;
  thread U x_thread[el_per_thread];

  // Adjust positions
  const int in_vec_size_w = in_vec_size / el_per_thread;
@@ -57,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in
  x += tid.z * in_vec_size;
  y += tid.z * out_vec_size;

  if (out_row >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of colgroup
  for (int i=0; i<in_vec_size; i+=colgroup) {
    // Load the vec to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid < simdgroups_fetching_vec) {
      x_block[lid] = x[lid + i];
    if (simd_gid == 0) {
#pragma clang loop unroll(full)
      for (int j=0; j<el_per_thread; j++) {
        x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
      }
    }
    if (simd_lid == 0) {
#pragma clang loop unroll(full)
@@ -90,7 +107,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
      // Do all the work.
#pragma clang loop unroll(full)
      for (int k=0; k<el_per_thread; k++) {
        result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
        result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
        w_local >>= bits;
      }
    }
@@ -100,7 +117,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in

  // Store the result
  if (simd_lid == 0) {
    y[out_row] = result;
    y[out_row] = static_cast<T>(result);
  }
}

@@ -129,15 +146,16 @@ template <typename T, const int BM, const int BN, const int group_size, const in
  constexpr int colgroup = BN * el_per_int;
  constexpr int groups_per_block = colgroup / group_size;

  threadgroup T scales_block[BM * groups_per_block];
  threadgroup T biases_block[BM * groups_per_block];
  threadgroup T x_block[BM];
  typedef typename AccT<T>::acc_t U;
  threadgroup U scales_block[BM * groups_per_block];
  threadgroup U biases_block[BM * groups_per_block];
  threadgroup U x_block[BM];

  thread uint32_t w_local;
  thread T result[el_per_int] = {0};
  thread T scale = 1;
  thread T bias = 0;
  thread T x_local = 0;
  thread U result[el_per_int] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size / el_per_int;
@@ -186,7 +204,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
      // Do all the work.
#pragma clang loop unroll(full)
      for (int k=0; k<el_per_int; k++) {
        result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
        result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
        w_local >>= bits;
      }
    }
@@ -201,7 +219,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k=0; k<el_per_int; k++) {
      y[k] = result[k];
      y[k] = static_cast<T>(result[k]);
    }
  }
}
@@ -240,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;

  threadgroup T scales_block[BN * groups_per_block];
  threadgroup T biases_block[BN * groups_per_block];
  threadgroup T Xs[BM * BK];
@@ -303,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
  const device uint32_t * w_local = w + offset_row * K_w + offset_col;
  threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;

  if (y_col + offset_col < N) {
  if (y_row + offset_row < N) {
    uint32_t wi = *w_local;
    T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
    T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -418,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
  for (int k=0; k<K; k += BK) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load the x tile
    if (num_els < BM) {
      loader_x.load_safe(short2(BK, num_els));
    short num_k = min(BK, K - k);
    if (num_els < BM || num_k < BK) {
      loader_x.load_safe(short2(num_k, num_els));
    } else {
      loader_x.load_unsafe();
    }
@@ -447,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_

    // Load the w tile
    {
      if (k + BK >= K) {
      if (num_k < BK) {
        for (int wo=0; wo<w_els_per_thread; wo++) {
          int offset = lid * w_els_per_thread + wo;
          int offset_row = offset / (BN / el_per_int);
mlx/backend/metal/kernels/rope.metal (new file, 68 lines)
@@ -0,0 +1,68 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
[[kernel]] void rope(
    const device T *in [[buffer(0)]],
    device T * out [[buffer(1)]],
    constant const size_t strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
    uint3 pos [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute the input and output indices
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + 1;
    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + grid.x;
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }

  // Figure out L and d.
  float L = scale * static_cast<float>(pos.y + offset);
  float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);

  // Compute costheta, sintheta
  float theta = L * metal::exp2(-d * base);
  float costheta = metal::fast::cos(theta);
  float sintheta = metal::fast::sin(theta);

  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1 = x1 * costheta - x2 * sintheta;
  float rx2 = x1 * sintheta + x2 * costheta;
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}
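The host passes `base` already as the log2 of the RoPE base (see rope.cpp further down), so `metal::exp2(-d * base)` evaluates to the base raised to -d. In the usual notation, for sequence position p, offset o, scale s, base \beta, and dimension fraction d = x / D, the kernel applies

\theta_{p,d} = s\,(p + o)\,2^{-d\log_2\beta} = s\,(p + o)\,\beta^{-d},
\qquad
\begin{pmatrix} y_1 \\ y_2 \end{pmatrix}
=
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix}

to each rotated pair, where the traditional variant pairs adjacent elements and the non-traditional variant pairs elements half a head apart.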
#define instantiate_rope(name, type, traditional) \
  template [[host_name("rope_" #name)]] \
  [[kernel]] void rope<type, traditional>( \
      const device type* in [[buffer(0)]], \
      device type* out [[buffer(1)]], \
      constant const size_t strides[3], \
      constant const int& offset, \
      constant const float& base, \
      constant const float& scale, \
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
mlx/backend/metal/kernels/scatter.metal (new file, 194 lines)
@@ -0,0 +1,194 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
    const device T *updates [[buffer(1)]],
    device mlx_atomic<T> *out [[buffer(2)]],
    const constant int *upd_shape [[buffer(3)]],
    const constant size_t *upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int *out_shape [[buffer(7)]],
    const constant size_t *out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {

  Op op;
  auto ind_idx = gid.y;
  auto ind_offset = gid.x;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  auto out_offset = elem_to_loc(
      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}

#define make_scatter_impl(IDX_ARG, IDX_ARR) \
  template <typename T, typename IdxT, typename Op, int NIDX> \
  [[kernel]] void scatter( \
      const device T *updates [[buffer(1)]], \
      device mlx_atomic<T> *out [[buffer(2)]], \
      const constant int *upd_shape [[buffer(3)]], \
      const constant size_t *upd_strides [[buffer(4)]], \
      const constant size_t& upd_ndim [[buffer(5)]], \
      const constant size_t& upd_size [[buffer(6)]], \
      const constant int *out_shape [[buffer(7)]], \
      const constant size_t *out_strides [[buffer(8)]], \
      const constant size_t& out_ndim [[buffer(9)]], \
      const constant int* axes [[buffer(10)]], \
      const constant int *idx_shapes [[buffer(11)]], \
      const constant size_t *idx_strides [[buffer(12)]], \
      const constant int& idx_ndim [[buffer(13)]], \
      IDX_ARG(IdxT) \
      uint2 gid [[thread_position_in_grid]]) { \
    \
    Indices<IdxT, NIDX> idxs{ \
        {{IDX_ARR()}}, \
        idx_shapes, \
        idx_strides, \
        idx_ndim}; \
    \
    return scatter_impl<T, IdxT, Op, NIDX>( \
        updates, \
        out, \
        upd_shape, \
        upd_strides, \
        upd_ndim, \
        upd_size, \
        out_shape, \
        out_strides, \
        out_ndim, \
        axes, \
        idxs, \
        gid); \
  }

#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)

make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)
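To make the macro plumbing concrete: `make_scatter(2)` defines a `scatter` kernel whose parameter list ends with the two extra index buffers from `IDX_ARG_2` and whose body packs them into an `Indices<IdxT, 2>` via `IDX_ARR_2`. A rough host-side C++ sketch of the generated shape, with Metal attributes dropped and all names illustrative:

#include <array>
#include <cstddef>

template <typename IdxT, int NIDX>
struct Indices {
  std::array<const IdxT*, NIDX> buffers;
  const int* shapes;
  const size_t* strides;
  int ndim;
};

// Roughly what make_scatter(2) produces: two trailing index pointers
// (buffers 21 and 22 become idx21/idx22) collected into Indices.
template <typename T, typename IdxT>
void scatter_nidx2(
    const T* updates,
    T* out,
    const int* idx_shapes,
    const size_t* idx_strides,
    int idx_ndim,
    const IdxT* idx21,
    const IdxT* idx22) {
  Indices<IdxT, 2> idxs{{{idx21, idx22}}, idx_shapes, idx_strides, idx_ndim};
  (void)updates; (void)out; (void)idxs; // scatter_impl would run here
}

int main() {}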
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
  template [[host_name("scatter" name "_" #nidx)]] \
  [[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
      const device src_t *updates [[buffer(1)]], \
      device mlx_atomic<src_t> *out [[buffer(2)]], \
      const constant int *upd_shape [[buffer(3)]], \
      const constant size_t *upd_strides [[buffer(4)]], \
      const constant size_t& upd_ndim [[buffer(5)]], \
      const constant size_t& upd_size [[buffer(6)]], \
      const constant int *out_shape [[buffer(7)]], \
      const constant size_t *out_strides [[buffer(8)]], \
      const constant size_t& out_ndim [[buffer(9)]], \
      const constant int* axes [[buffer(10)]], \
      const constant int *idx_shapes [[buffer(11)]], \
      const constant size_t *idx_strides [[buffer(12)]], \
      const constant int& idx_ndim [[buffer(13)]], \
      IDX_ARG(idx_t) \
      uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
  instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)

#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
  instantiate_scatter4(name, type, ind_type, op_type, 3) \
  instantiate_scatter4(name, type, ind_type, op_type, 4) \
  instantiate_scatter4(name, type, ind_type, op_type, 5) \
  instantiate_scatter4(name, type, ind_type, op_type, 6) \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10)

#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>)

#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
  instantiate_scatter2(#name "uint16", type, uint16_t) \
  instantiate_scatter2(#name "uint32", type, uint32_t) \
  instantiate_scatter2(#name "uint64", type, uint64_t) \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t)

// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)

instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
@@ -89,20 +89,9 @@ struct GEMMKernel {
    // Appease the compiler
    (void)l;

    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    if (!M_aligned) {
      short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
      loader_a.set_mask(tile_dims_A, mask_A);
    }

    if (!N_aligned) {
      short2 tile_dims_B =
          transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
      loader_b.set_mask(tile_dims_B, mask_B);
    }
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -110,13 +99,13 @@ struct GEMMKernel {
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(mask_A);
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(mask_B);
        loader_b.load_safe(tile_dims_B);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -137,11 +126,8 @@ struct GEMMKernel {
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.set_mask(tile_dims_A_last, mask_A);
      loader_b.set_mask(tile_dims_B_last, mask_B);

      loader_a.load_safe(mask_A);
      loader_b.load_safe(mask_B);
      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -218,14 +204,8 @@ struct GEMMKernel {
    short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
    short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];

    loader_a.set_mask(tile_dims_A, mask_A);
    loader_b.set_mask(tile_dims_B, mask_B);

    loader_a.load_safe(mask_A);
    loader_b.load_safe(mask_B);
    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -112,14 +112,8 @@ template <typename T,
    short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
    short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];

    loader_a.set_mask(tile_dims_A, mask_A);
    loader_b.set_mask(tile_dims_B, mask_B);

    loader_a.load_safe(mask_A);
    loader_b.load_safe(mask_B);
    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -67,24 +67,22 @@ struct BlockLoader {
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void set_mask(
      thread const short2& src_tile_dims,
      thread bool mask[n_rows][vec_size]) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        mask[i][j] =
            ((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];
@@ -117,39 +115,6 @@ struct BlockLoader {
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
      simdgroup_barrier(mem_flags::mem_none);
      // Use fast thread memory for bound checks

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
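The refactor replaces the per-thread boolean mask arrays with a single `load_safe(short2)` that clips the tile dimensions against each thread's (bj, bi) offset before loading, zero-filling anything past the edge. A much-simplified host-side C++ sketch of that clipping and zero-fill, reduced to one row of vec_size elements (all names illustrative):

#include <cassert>

// Clip a tile against this thread's column offset: if nothing is
// readable, or past the valid prefix, write zeros instead of reading.
void load_row(const float* src, float* dst, int tile_x, int bj, int vec_size) {
  int valid = tile_x - bj; // elements this thread may still read
  for (int j = 0; j < vec_size; j++) {
    dst[j] = (j < valid) ? src[j] : 0.0f;
  }
}

int main() {
  float src[4] = {1, 2, 3, 4};
  float dst[4];
  load_row(src, dst, /*tile_x=*/3, /*bj=*/1, /*vec_size=*/4);
  assert(dst[0] == 1 && dst[1] == 2 && dst[2] == 0 && dst[3] == 0);
}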
mlx/backend/metal/kernels/unary.h (new file, 376 lines)
@@ -0,0 +1,376 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/utils.h"

struct Abs {
  template <typename T>
  T operator()(T x) {
    return metal::abs(x);
  };
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  };
  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  };
  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  };
  template <>
  bool operator()(bool x) {
    return x;
  };
  template <>
  complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  };
};

struct ArcCos {
  template <typename T>
  T operator()(T x) {
    return metal::precise::acos(x);
  };
};

struct ArcCosh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::acosh(x);
  };
};

struct ArcSin {
  template <typename T>
  T operator()(T x) {
    return metal::precise::asin(x);
  };
};

struct ArcSinh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::asinh(x);
  };
};

struct ArcTan {
  template <typename T>
  T operator()(T x) {
    return metal::precise::atan(x);
  };
};

struct ArcTanh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::atanh(x);
  };
};

struct Ceil {
  template <typename T>
  T operator()(T x) {
    return metal::ceil(x);
  };
  template <>
  int8_t operator()(int8_t x) {
    return x;
  };
  template <>
  int16_t operator()(int16_t x) {
    return x;
  };
  template <>
  int32_t operator()(int32_t x) {
    return x;
  };
  template <>
  int64_t operator()(int64_t x) {
    return x;
  };
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  };
  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  };
  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  };
  template <>
  bool operator()(bool x) {
    return x;
  };
};

struct Cos {
  template <typename T>
  T operator()(T x) {
    return metal::precise::cos(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
        -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
  };
};

struct Cosh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::cosh(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
        metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
  };
};

struct Erf {
  template <typename T>
  T operator()(T x) {
    return static_cast<T>(erf(static_cast<float>(x)));
  };
};

struct ErfInv {
  template <typename T>
  T operator()(T x) {
    return static_cast<T>(erfinv(static_cast<float>(x)));
  };
};

struct Exp {
  template <typename T>
  T operator()(T x) {
    return metal::precise::exp(x);
  };
  template <>
  complex64_t operator()(complex64_t x) {
    auto m = metal::precise::exp(x.real);
    return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
  }
};

struct Floor {
  template <typename T>
  T operator()(T x) {
    return metal::floor(x);
  };
  template <>
  int8_t operator()(int8_t x) {
    return x;
  };
  template <>
  int16_t operator()(int16_t x) {
    return x;
  };
  template <>
  int32_t operator()(int32_t x) {
    return x;
  };
  template <>
  int64_t operator()(int64_t x) {
    return x;
  };
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  };
  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  };
  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  };
  template <>
  bool operator()(bool x) {
    return x;
  };
};

struct Log {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log(x);
  };
};

struct Log2 {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log2(x);
  };
};

struct Log10 {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log10(x);
  };
};

struct Log1p {
  template <typename T>
  T operator()(T x) {
    return log1p(x);
  };
};

struct LogicalNot {
  template <typename T>
  T operator()(T x) {
    return !x;
  };
};

struct Negative {
  template <typename T>
  T operator()(T x) {
    return -x;
  };
};

struct Round {
  template <typename T>
  T operator()(T x) {
    return metal::rint(x);
  };
  template <>
  complex64_t operator()(complex64_t x) {
    return {metal::rint(x.real), metal::rint(x.imag)};
  };
};

struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto y = 1 / (1 + metal::exp(-metal::abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};
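`Sigmoid` is written in the numerically stable form: `exp` only ever sees a non-positive argument, so every intermediate stays finite. For x < 0 the identity 1 - 1/(1 + e^x) = e^x/(1 + e^x) = sigmoid(x) recovers the correct value. A quick C++ check of both branches:

#include <cassert>
#include <cmath>

// Stable sigmoid: the naive 1/(1+exp(-x)) produces an intermediate inf
// for large negative x before collapsing back to 0; this form never does.
float sigmoid(float x) {
  float y = 1.0f / (1.0f + std::exp(-std::fabs(x)));
  return (x < 0) ? 1.0f - y : y;
}

int main() {
  assert(sigmoid(0.0f) == 0.5f);
  assert(sigmoid(-1000.0f) == 0.0f);
  assert(std::fabs(sigmoid(2.0f) + sigmoid(-2.0f) - 1.0f) < 1e-6f);
}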
struct Sign {
  template <typename T>
  T operator()(T x) {
    return (x > T(0)) - (x < T(0));
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x != 0;
  };
};

struct Sin {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sin(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
        metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
  };
};

struct Sinh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sinh(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
        metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
  };
};

struct Square {
  template <typename T>
  T operator()(T x) {
    return x * x;
  };
};

struct Sqrt {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sqrt(x);
  };
};

struct Rsqrt {
  template <typename T>
  T operator()(T x) {
    return metal::precise::rsqrt(x);
  };
};

struct Tan {
  template <typename T>
  T operator()(T x) {
    return metal::precise::tan(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    float tan_a = metal::precise::tan(x.real);
    float tanh_b = metal::precise::tanh(x.imag);
    float t1 = tan_a * tanh_b;
    float denom = 1. + t1 * t1;
    return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
  };
};
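The complex specialization follows from the tangent addition formula. Writing t_1 = \tan a\,\tanh b and rationalizing by the conjugate gives exactly the real and imaginary parts computed above:

\tan(a + ib)
  = \frac{\tan a + i\tanh b}{1 - i\,\tan a\,\tanh b}
  = \frac{(\tan a + i\tanh b)(1 + i t_1)}{1 + t_1^2}
  = \frac{\tan a - \tanh b\, t_1}{1 + t_1^2}
  + i\,\frac{\tanh b + \tan a\, t_1}{1 + t_1^2}

The `Tanh` specialization below is the same identity with the roles of the circular and hyperbolic factors exchanged.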
struct Tanh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::tanh(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    float tanh_a = metal::precise::tanh(x.real);
    float tan_b = metal::precise::tan(x.imag);
    float t1 = tanh_a * tan_b;
    float denom = 1. + t1 * t1;
    return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
  };
};
@@ -1,223 +1,6 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/bf16.h"

struct Abs {
  template <typename T> T operator()(T x) { return metal::abs(x); };
  template <> uint8_t operator()(uint8_t x) { return x; };
  template <> uint16_t operator()(uint16_t x) { return x; };
  template <> uint32_t operator()(uint32_t x) { return x; };
  template <> uint64_t operator()(uint64_t x) { return x; };
  template <> bool operator()(bool x) { return x; };
  template <> complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  };
};

struct ArcCos {
  template <typename T> T operator()(T x) { return metal::precise::acos(x); };
};

struct ArcCosh {
  template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
};

struct ArcSin {
  template <typename T> T operator()(T x) { return metal::precise::asin(x); };
};

struct ArcSinh {
  template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
};

struct ArcTan {
  template <typename T> T operator()(T x) { return metal::precise::atan(x); };
};

struct ArcTanh {
  template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};

struct Ceil {
  template <typename T> T operator()(T x) { return metal::ceil(x); };
  template <> int8_t operator()(int8_t x) { return x; };
  template <> int16_t operator()(int16_t x) { return x; };
  template <> int32_t operator()(int32_t x) { return x; };
  template <> int64_t operator()(int64_t x) { return x; };
  template <> uint8_t operator()(uint8_t x) { return x; };
  template <> uint16_t operator()(uint16_t x) { return x; };
  template <> uint32_t operator()(uint32_t x) { return x; };
  template <> uint64_t operator()(uint64_t x) { return x; };
  template <> bool operator()(bool x) { return x; };
};

struct Cos {
  template <typename T> T operator()(T x) { return metal::precise::cos(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
      metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
      -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
    };
  };
};

struct Cosh {
  template <typename T> T operator()(T x) { return metal::precise::cosh(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
      metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
      metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
    };
  };
};

struct Erf {
  template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
};

struct ErfInv {
  template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
};

struct Exp {
  template <typename T> T operator()(T x) { return metal::precise::exp(x); };
  template <> complex64_t operator()(complex64_t x) {
    auto m = metal::precise::exp(x.real);
    return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
  }
};

struct Floor {
  template <typename T> T operator()(T x) { return metal::floor(x); };
  template <> int8_t operator()(int8_t x) { return x; };
  template <> int16_t operator()(int16_t x) { return x; };
  template <> int32_t operator()(int32_t x) { return x; };
  template <> int64_t operator()(int64_t x) { return x; };
  template <> uint8_t operator()(uint8_t x) { return x; };
  template <> uint16_t operator()(uint16_t x) { return x; };
  template <> uint32_t operator()(uint32_t x) { return x; };
  template <> uint64_t operator()(uint64_t x) { return x; };
  template <> bool operator()(bool x) { return x; };
};

struct Log {
  template <typename T> T operator()(T x) { return metal::precise::log(x); };
};

struct Log2 {
  template <typename T> T operator()(T x) { return metal::precise::log2(x); };
};

struct Log10 {
  template <typename T> T operator()(T x) { return metal::precise::log10(x); };
};

struct Log1p {
  template <typename T> T operator()(T x) { return log1p(x); };
};

struct LogicalNot {
  template <typename T> T operator()(T x) { return !x; };
};

struct Negative {
  template <typename T> T operator()(T x) { return -x; };
};

struct Round {
  template <typename T> T operator()(T x) { return metal::rint(x); };
  template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
};

struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto y = 1 / (1 + metal::exp(-metal::abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};

struct Sign {
  template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
  template <> uint32_t operator()(uint32_t x) { return x != 0; };
};

struct Sin {
  template <typename T> T operator()(T x) { return metal::precise::sin(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
      metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
      metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
    };
  };
};

struct Sinh {
  template <typename T> T operator()(T x) { return metal::precise::sinh(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
      metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
      metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
    };
  };
};

struct Square {
  template <typename T> T operator()(T x) { return x * x; };
};

struct Sqrt {
  template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
};

struct Rsqrt {
  template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
};

struct Tan {
  template <typename T> T operator()(T x) { return metal::precise::tan(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    float tan_a = metal::precise::tan(x.real);
    float tanh_b = metal::precise::tanh(x.imag);
    float t1 = tan_a * tanh_b;
    float denom = 1. + t1 * t1;
    return {
      (tan_a - tanh_b * t1) / denom,
      (tanh_b + tan_a * t1) / denom
    };
  };
};

struct Tanh {
  template <typename T> T operator()(T x) { return metal::precise::tanh(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    float tanh_a = metal::precise::tanh(x.real);
    float tan_b = metal::precise::tan(x.imag);
    float t1 = tanh_a * tan_b;
    float denom = 1. + t1 * t1;
    return {
      (tanh_a + tan_b * t1) / denom,
      (tan_b - tanh_a * t1) / denom
    };
  };
};
#include "mlx/backend/metal/kernels/unary.h"

template <typename T, typename Op>
[[kernel]] void unary_op_v(
@@ -12,10 +12,10 @@

template <typename U>
struct Limits {
  static const constant U max;
  static const constant U min;
  static const constant U finite_max;
  static const constant U finite_min;
  static const constant U max = metal::numeric_limits<U>::max();
  static const constant U min = metal::numeric_limits<U>::min();
  static const constant U finite_max = metal::numeric_limits<U>::max();
  static const constant U finite_min = metal::numeric_limits<U>::min();
};

#define instantiate_default_limit(type) \
@@ -71,7 +71,7 @@ inline size_t elem_to_loc(
    device const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
@@ -84,7 +84,7 @@ inline size_t elem_to_loc(
    constant const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
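`elem_to_loc` converts a flat element index into a strided memory offset by peeling digits in the mixed-radix system given by `shape`; the added `elem > 0` condition exits as soon as all digits are consumed, saving the remaining divisions for small indices. A host-side C++ version with a worked example:

#include <cassert>
#include <cstddef>

// Flat index -> strided offset; shape digits are peeled from the last
// axis inward, and the loop stops early once elem runs out of digits.
size_t elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed: strides {1, 2} instead of {3, 1}.
  int shape[] = {2, 3};
  size_t strides[] = {1, 2};
  // Flat element 4 is (row 1, col 1) -> offset 1*1 + 1*2 = 3.
  assert(elem_to_loc(4, shape, strides, 2) == 3);
  // Element 0 exits immediately via the elem > 0 check.
  assert(elem_to_loc(0, shape, strides, 2) == 0);
}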
@@ -273,4 +273,4 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
}
mlx/backend/metal/make_compiled_preamble.sh (new file, 28 lines)
@@ -0,0 +1,28 @@
#!/bin/bash
#
# This script generates a C++ function that provides the Metal unary and binary
# ops at runtime for use with kernel generation.
#
# Copyright © 2023-24 Apple Inc.


OUTPUT_FILE=$1
CC=$2
SRCDIR=$3

CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)

cat << EOF > "$OUTPUT_FILE"
// Copyright © 2023-24 Apple Inc.

namespace mlx::core::metal {

const char* get_kernel_preamble() {
  return R"preamble(
$CONTENT
)preamble";
}

} // namespace mlx::core::metal
EOF
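The heredoc wraps the preprocessed header in a delimited raw string literal; the `preamble` delimiter matters because the pasted source itself can contain a bare `)"` sequence, which would terminate a plain `R"(...)"` literal early. A minimal C++ illustration of why the delimiter is needed (contents are illustrative):

#include <cstdio>

int main() {
  // A delimited raw string safely embeds )" sequences; plain R"(...)"
  // would end at the first one.
  const char* preamble = R"preamble(
struct Example { /* a comment containing )" would break R"(...)" */ };
)preamble";
  std::puts(preamble);
}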
@@ -63,7 +63,15 @@ std::function<void()> make_task(
  auto s = arr.primitive().stream();
  auto command_buffer = increment_command_buffer(s);
  auto outputs = arr.outputs();
  arr.primitive().eval_gpu(arr.inputs(), outputs);
  {
    // If the array is a tracer hold a reference
    // to its inputs so they don't get donated
    std::vector<array> inputs;
    if (arr.is_tracer()) {
      inputs = arr.inputs();
    }
    arr.primitive().eval_gpu(arr.inputs(), outputs);
  }
  std::vector<std::shared_ptr<array::Data>> buffers;
  for (auto& in : arr.inputs()) {
    buffers.push_back(in.data_shared_ptr());
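The guard works purely through reference counting: copying `arr.inputs()` into a local vector bumps each input's use count for the duration of `eval_gpu`, so any "am I the sole owner?" donation check fails and the tracer's inputs keep their storage. A stripped-down C++ sketch of the mechanism (the `Array` type and the `is_donatable` rule here are illustrative stand-ins):

#include <cassert>
#include <memory>
#include <vector>

struct Array {
  std::shared_ptr<int> data; // stand-in for array::Data
  // A buffer may be donated (reused as an output) only if nothing
  // else holds a reference to it.
  bool is_donatable() const { return data.use_count() == 1; }
};

int main() {
  Array in{std::make_shared<int>(42)};
  assert(in.is_donatable());

  {
    // Equivalent of `inputs = arr.inputs()` in the tracer branch: the
    // extra copy keeps the use count above 1 across eval_gpu.
    std::vector<Array> hold = {in};
    assert(!in.is_donatable());
  }
  assert(in.is_donatable()); // reference released when the scope ends
}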
@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  int bo = std::min(32, O);
  int bd = 32;
  MTL::Size group_dims = MTL::Size(bd, bo, 1);
  MTL::Size grid_dims = MTL::Size(1, O / bo, B);
  MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);

  set_array_buffer(compute_encoder, w, 0);
  set_array_buffer(compute_encoder, scales, 1);
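The grid-size fix is the standard ceiling division: `O / bo` silently drops the final partial block whenever O is not a multiple of bo, leaving the tail rows uncomputed, while `(O + bo - 1) / bo` rounds up. The idiom in isolation:

#include <cassert>

// Ceiling division for non-negative ints: number of size-b blocks
// needed to cover n elements, counting a final partial block.
constexpr int ceil_div(int n, int b) { return (n + b - 1) / b; }

int main() {
  assert(ceil_div(64, 32) == 2); // exact fit
  assert(ceil_div(65, 32) == 3); // 65 / 32 == 2 would miss a row
  static_assert(ceil_div(1, 32) == 1);
  return 0;
}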
mlx/backend/metal/rope.cpp (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/utils.h"
#include "mlx/fast.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  if (in.ndim() != 3) {
    throw std::runtime_error(
        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
  }
  if (dims_ != in.shape(-1)) {
    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
  }
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  std::ostringstream kname;
  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  bool donated = in.data_shared_ptr() == nullptr;
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, donated ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&offset_, sizeof(int), 3);
  compute_encoder->setBytes(&base, sizeof(float), 4);
  compute_encoder->setBytes(&scale_, sizeof(float), 5);

  int dim0 = in.shape(2) / 2;
  int dim1 = in.shape(1);
  int dim2 = in.shape(0);
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace mlx::core::fast
@@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   std::vector<array> copies;
   auto check_input = [&copies, &s](const array& x) {
-    if (x.strides()[x.ndim() - 1] == 1) {
+    bool no_copy = x.strides()[x.ndim() - 1] == 1;
+    if (x.ndim() > 1) {
+      auto s = x.strides()[x.ndim() - 2];
+      no_copy &= (s == 0 || s == x.shape().back());
+    }
+    if (no_copy) {
       return x;
     } else {
       array x_copy(x.shape(), x.dtype(), nullptr, {});
@@ -9,20 +9,6 @@ namespace mlx::core {

 namespace {

-void set_array_buffer(
-    MTL::ComputeCommandEncoder* compute_encoder,
-    MTL::ArgumentEncoder* enc,
-    const array& a,
-    int idx) {
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  enc->setBuffer(a_buf, offset, idx);
-  // MTL::Resource usage through argument buffer needs to be explicitly
-  // flagged to enable hazard tracking
-  compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
-}
-
 void set_array_buffer(
     MTL::ComputeCommandEncoder* enc,
     const array& a,
@@ -117,16 +103,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
 std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(const std::vector<array>& xs) {
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<size_t>> strides) {
   // Make a vector that has axes separated with -1. Collapse all axes between
   // -1.
   std::vector<int> to_collapse;
-  if (xs[0].ndim() > 0) {
+  if (shape.size() > 0) {
     to_collapse.push_back(0);
-    for (int i = 1; i < xs[0].ndim(); i++) {
+    for (int i = 1; i < shape.size(); i++) {
       bool contiguous = true;
-      for (auto& x : xs) {
-        if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
+      for (const std::vector<size_t>& st : strides) {
+        if (st[i] * shape[i] != st[i - 1]) {
           contiguous = false;
         }
         if (!contiguous) {
@@ -142,21 +130,31 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
   }

   std::vector<int> out_shape;
-  std::vector<std::vector<size_t>> out_strides(xs.size());
+  std::vector<std::vector<size_t>> out_strides(strides.size());
   for (int i = 0; i < to_collapse.size(); i++) {
-    int current_shape = xs[0].shape()[to_collapse[i]];
+    int current_shape = shape[to_collapse[i]];
     while (to_collapse[++i] != -1) {
-      current_shape *= xs[0].shape()[to_collapse[i]];
+      current_shape *= shape[to_collapse[i]];
     }
     out_shape.push_back(current_shape);
-    for (int j = 0; j < xs.size(); j++) {
-      out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
+    for (int j = 0; j < strides.size(); j++) {
+      const std::vector<size_t>& st = strides[j];
+      out_strides[j].push_back(st[to_collapse[i - 1]]);
     }
   }

   return std::make_tuple(out_shape, out_strides);
 }

+std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+collapse_contiguous_dims(const std::vector<array>& xs) {
+  std::vector<std::vector<size_t>> strides;
+  for (auto& x : xs) {
+    strides.emplace_back(x.strides());
+  }
+  return collapse_contiguous_dims(xs[0].shape(), strides);
+}
+
 template <typename... Arrays>
 std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
 collapse_contiguous_dims(Arrays... xs) {
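A worked example of what `collapse_contiguous_dims` computes: two adjacent axes can be merged exactly when `strides[i] * shape[i] == strides[i - 1]` holds for every array, so a row-contiguous shape {2, 3, 4} with strides {12, 4, 1} collapses all the way to shape {24} with stride {1}. A self-contained check of that criterion:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> shape = {2, 3, 4};
      std::vector<size_t> st = {12, 4, 1};  // row-contiguous strides
      bool collapsible = true;
      for (size_t i = 1; i < shape.size(); ++i) {
        collapsible &= (st[i] * shape[i] == st[i - 1]);
      }
      printf("collapsible: %d\n", collapsible);  // prints 1
    }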
@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/primitives.h"
+#include "mlx/fast.h"

 #define NO_GPU_MULTI(func) \
   void func::eval_gpu( \
@@ -32,6 +33,7 @@ NO_GPU(AsType)
 NO_GPU(AsStrided)
 NO_GPU(Broadcast)
 NO_GPU(Ceil)
+NO_GPU_MULTI(Compiled)
 NO_GPU(Concatenate)
 NO_GPU(Convolution)
 NO_GPU(Copy)
@@ -40,6 +42,7 @@ NO_GPU(Cosh)
 NO_GPU_MULTI(CustomVJP)
 NO_GPU_MULTI(Depends)
 NO_GPU(Divide)
+NO_GPU_MULTI(DivMod)
 NO_GPU(Remainder)
 NO_GPU(Equal)
 NO_GPU(Erf)
@@ -69,6 +72,7 @@ NO_GPU(NotEqual)
 NO_GPU(Pad)
 NO_GPU(Partition)
 NO_GPU(Power)
+NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(RandomBits)
 NO_GPU(Reduce)
@@ -91,6 +95,9 @@ NO_GPU(Subtract)
 NO_GPU(Tan)
 NO_GPU(Tanh)
 NO_GPU(Transpose)
-NO_GPU_MULTI(DivMod)
-NO_GPU_MULTI(QRF)
+
+namespace fast {
+NO_GPU_MULTI(RoPE)
+} // namespace fast

 } // namespace mlx::core

mlx/compile.cpp (442 changes)
@@ -1,36 +1,168 @@
 // Copyright © 2023-2024 Apple Inc.

 #include <cstdlib>
 #include <map>
 #include <unordered_map>
 #include <unordered_set>

 #include "mlx/allocator.h"
+#include "mlx/compile.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"

 namespace mlx::core {

-namespace detail {
-
-bool& compiler_disabled() {
-  auto get_val = []() {
-    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
-      return true;
-    } else {
-      return false;
-    }
-  };
-  static bool compiler_disabled_ = get_val();
-  return compiler_disabled_;
-}
-
-#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
+constexpr int max_compile_depth = 10;
+
+bool is_unary(const Primitive& p) {
+  return (
+      typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
+      typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
+      typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
+      typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
+      typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
+      typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
+      typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
+      typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
+      typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
+      typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
+      typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
+      typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
+      typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
+      typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
+      typeid(p) == typeid(Tanh));
+}
+
+bool is_binary(const Primitive& p) {
+  return (
+      typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
+      typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
+      typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
+      typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
+      typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
+      typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
+      typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
+      typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
+      typeid(p) == typeid(Subtract));
+}
+
+bool is_broadcast(const Primitive& p) {
+  return typeid(p) == typeid(Broadcast);
+}
+
+bool is_noop(const Primitive& p) {
+  return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
+}
+
+bool is_fusable(const Primitive& p) {
+  return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
+}
+
+namespace detail {
+
+std::vector<array> compile_replace(
+    const std::vector<array>& tape,
+    const std::vector<array>& trace_inputs,
+    const std::vector<array>& trace_outputs,
+    const std::vector<array>& inputs);
+
+} // namespace detail
+
+Compiled::Compiled(
+    Stream stream,
+    std::vector<array> inputs,
+    std::vector<array> outputs,
+    std::vector<array> tape,
+    std::unordered_set<uintptr_t> constant_ids)
+    : Primitive(stream),
+      inputs_(std::move(inputs)),
+      outputs_(std::move(outputs)),
+      tape_(std::move(tape)),
+      constant_ids_(std::move(constant_ids)) {}
+
+std::vector<array> Compiled::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>& outputs) {
+  throw std::runtime_error("[Compiled] Cannot vjp primitive.");
+}
+
+std::vector<array> Compiled::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  throw std::runtime_error("[Compiled] Cannot jvp primitive.");
+}
+
+std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  throw std::runtime_error("[Compiled] Cannot vmap primitive.");
+}
+
+bool Compiled::is_equivalent(const Primitive& other) const {
+  const Compiled& a_other = static_cast<const Compiled&>(other);
+  return std::equal(
+      tape_.begin(),
+      tape_.end(),
+      a_other.tape_.begin(),
+      a_other.tape_.end(),
+      [](const array& a1, const array& a2) {
+        auto& p1 = a1.primitive();
+        auto& p2 = a2.primitive();
+        return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
+      });
+}
+
+void Compiled::print(std::ostream& os) {
+  os << "Compiled";
+  for (auto& a : tape_) {
+    a.primitive().print(os);
+  }
+}
+
+namespace detail {
+
+CompileMode& compile_mode() {
+  auto get_val = []() {
+    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
+      return CompileMode::disabled;
+    } else {
+      return CompileMode::enabled;
+    }
+  };
+  static CompileMode compile_mode_ = get_val();
+  return compile_mode_;
+}
+
+using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
+using ParentsMap =
+    std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
+
+// Helper that merges two arrays in the graph by setting the parents of the
+// source to point to the destination
+void merge(array& dst, array& src, ParentsMap& parents_map) {
+  // Canonicalize the order of the primitives outputs
+  auto sources = src.outputs();
+  auto dests = dst.outputs();
+  // For each src parent, point it to the corresponding dst
+  for (int i = 0; i < sources.size(); ++i) {
+    auto src_parents = parents_map.find(sources[i].id());
+    if (src_parents == parents_map.end()) {
+      continue;
+    }
+    auto& pairs = parents_map[dests[i].id()];
+    for (auto& parent : src_parents->second) {
+      parent.first.inputs()[parent.second] = dests[i];
+      pairs.push_back(parent);
+    }
+    // Remove the source from the map to avoid fusing with it again
+    parents_map.erase(src_parents);
+  }
+};

 template <typename T, typename... U>
 size_t getAddress(std::function<T(U...)> f) {
   typedef T(fnType)(U...);
@@ -59,9 +191,7 @@ struct CompilerCache {
     auto is_match = [](const std::vector<array>& in1,
                        const std::vector<array>& in2) {
       if (in1.size() != in2.size()) {
-        throw std::runtime_error(
-            "[compiler] Got different number of inputs to function,"
-            " this should never happen.");
+        return false;
       }
       for (int i = 0; i < in1.size(); ++i) {
         if (in1[i].shape() != in2[i].shape()) {
@@ -205,28 +335,6 @@ void compile_simplify(
     }
   }

-  // Helper that fuses two arrays in the graph by setting the parents of the
-  // source to point to the destination
-  auto fuse = [&](array& dst, array& src) {
-    // Canonicalize the order of the primitives outputs
-    auto sources = src.outputs();
-    auto dests = dst.outputs();
-    // For each src parent, point it to the corresponding dest
-    for (int i = 0; i < sources.size(); ++i) {
-      auto src_parents = parents_map.find(sources[i].id());
-      if (src_parents == parents_map.end()) {
-        continue;
-      }
-      auto& pairs = parents_map[dests[i].id()];
-      for (auto& parent : src_parents->second) {
-        parent.first.inputs()[parent.second] = dests[i];
-        pairs.push_back(parent);
-      }
-      // Remove the source from the map to avoid fusing with it again
-      parents_map.erase(src_parents);
-    }
-  };
-
   // Depth-1 array equivalence check.
   auto array_equivalent = [](const array& a, const array& b) {
     if (!a.has_primitive() || !b.has_primitive()) {
@@ -254,33 +362,32 @@ void compile_simplify(
     return pa.is_equivalent(pb);
   };

-  // Pass 0: fuse scalars
+  // Merge scalars
   std::vector<array> new_tape;
   for (auto& arr : tape) {
-    // Check if we can fuse scalars
+    // Check if we can merge scalars
     if (is_scalar(arr)) {
       auto scalar = scalars.find(get_scalar_rep(arr));
       if (scalar->second.id() != arr.id()) {
-        fuse(scalar->second, arr);
+        merge(scalar->second, arr, parents_map);
         // Don't keep orphaned scalars in the tape
         continue;
       }
     }
     new_tape.push_back(std::move(arr));
   }

   tape = std::move(new_tape);

   std::unordered_set<uintptr_t> output_set;
   for (auto& o : outputs) {
     output_set.insert(o.id());
   }
-  // Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
+  // Multi-pass merge only keeping non-orphaned arrays in the tape
   for (int pass = 0; pass < passes; ++pass) {
     for (auto& arr : tape) {
-      // Helper to check if we can fuse the parents of the
+      // Helper to check if we can merge the parents of the
       // given array
-      auto maybe_fuse_parents = [&](auto& a) {
+      auto maybe_merge_parents = [&](auto& a) {
         auto parents = parents_map.find(a.id());
         if (parents != parents_map.end()) {
           auto N = parents->second.size();
@@ -296,7 +403,7 @@ void compile_simplify(
           auto& src = parents->second[j].first;
           auto& dst = parents->second[i].first;
           if (src.id() != dst.id() && array_equivalent(src, dst)) {
-            fuse(dst, src);
+            merge(dst, src, parents_map);
             mask[j] = true;
           }
         }
@@ -313,9 +420,9 @@ void compile_simplify(
       }
     };

-    bool discard = maybe_fuse_parents(arr);
+    bool discard = maybe_merge_parents(arr);
     for (auto& s : arr.siblings()) {
-      discard &= maybe_fuse_parents(s);
+      discard &= maybe_merge_parents(s);
     }
     // If an array and its siblings have no parents, and none of them are
     // outputs, it is safe to remove it from the tape
@@ -327,6 +434,216 @@ void compile_simplify(
   }
 }

+// Extract sub-graphs of the graph that can be compiled
+// and replace them with a Compiled Primitive.
+void compile_fuse(
+    std::vector<array>& tape,
+    ParentsMap& parents_map,
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // Track outputs to replace with new compiled outputs
+  std::unordered_map<uintptr_t, array> output_map;
+  for (auto& o : outputs) {
+    output_map.insert({o.id(), o});
+  }
+
+  // Set of inputs to distinguish constants
+  std::unordered_set<uintptr_t> input_ids;
+  for (auto& in : inputs) {
+    input_ids.insert(in.id());
+  }
+
+  // Go through the tape in reverse order and check for fusable sub-graphs
+  std::vector<array> new_tape;
+  std::unordered_set<uintptr_t> global_cache;
+  for (int i = tape.size() - 1; i >= 0; --i) {
+    auto& arr = tape[i];
+
+    // Already compiled
+    if (global_cache.find(arr.id()) != global_cache.end()) {
+      continue;
+    }
+
+    // Two pass recursion:
+    // First pass:
+    //  - Collect all the primitives which we can fuse with
+    //  - Keeps a cache of fusable primitives which may be added out of
+    //    DAG order. We have to determine if all of a fused primitive's
+    //    outputs are also in the fused section, and this may not be the
+    //    case the first time we visit it.
+    // Second pass:
+    //  - Collect inputs to the new compiled primitive
+    //  - Add fusable primitives to a tape in the correct order
+
+    std::function<void(const array&, int, const Stream&)> recurse;
+    std::unordered_set<uintptr_t> cache;
+    recurse = [&](const array& a, int depth, const Stream& s) {
+      if (cache.find(a.id()) != cache.end()) {
+        return;
+      }
+
+      // Stop fusing if:
+      //  - Depth limit exceeded
+      //  - Constant input
+      //  - Stream mismatch
+      //  - Non fusable primitive
+      if (depth >= max_compile_depth || !a.has_primitive() ||
+          a.primitive().stream() != s || !is_fusable(a.primitive())) {
+        return;
+      }
+
+      bool all_parents_in = true;
+      if (depth > 0) {
+        // Guaranteed to have a parent since nested in the
+        // recursion.
+        auto& parents = parents_map.at(a.id());
+        for (auto& [p, idx] : parents) {
+          auto in_cache = cache.find(p.id()) != cache.end();
+          if (!in_cache) {
+            all_parents_in = false;
+            break;
+          }
+        }
+      }
+
+      // Arrays with a mix of parents outside the compilable section
+      // are not fusable
+      if (!all_parents_in) {
+        return;
+      }
+
+      cache.insert({a.id()});
+
+      for (auto& in : a.inputs()) {
+        recurse(in, depth + 1, s);
+      }
+    };
+
+    if (arr.has_primitive()) {
+      Stream s = arr.primitive().stream();
+      recurse(arr, 0, s);
+    }
+
+    // Not worth fusing a single primitive
+    if (cache.size() <= 1) {
+      new_tape.push_back(arr);
+      continue;
+    }
+
+    // Recurse a second time to build the tape in the right
+    // order and collect the inputs
+    std::unordered_set<uintptr_t> input_set;
+    std::vector<array> inputs;
+    std::vector<array> fused_tape;
+    std::unordered_set<uintptr_t> tape_set;
+    std::function<void(const array&)> recurse_tape;
+    recurse_tape = [&](const array& a) {
+      if (cache.find(a.id()) == cache.end()) {
+        if (input_set.find(a.id()) == input_set.end()) {
+          input_set.insert(a.id());
+          inputs.push_back(a);
+        }
+        return;
+      }
+      if (tape_set.find(a.id()) != tape_set.end()) {
+        return;
+      }
+      tape_set.insert(a.id());
+      for (auto& in : a.inputs()) {
+        recurse_tape(in);
+      }
+      fused_tape.push_back(a);
+    };
+    recurse_tape(arr);
+
+    std::vector<array> old_outputs;
+    // Add to global cache and add any global outputs to outputs
+    // of new primitive
+    for (int j = 0; j < fused_tape.size() - 1; ++j) {
+      auto& f = fused_tape[j];
+      if (output_map.find(f.id()) != output_map.end()) {
+        old_outputs.push_back(f);
+        // Parents are now siblings, update the parent map
+        auto& pairs = parents_map[f.id()];
+        pairs.erase(
+            std::remove_if(
+                pairs.begin(),
+                pairs.end(),
+                [&](auto& p) {
+                  return cache.find(p.first.id()) != cache.end();
+                }),
+            pairs.end());
+      } else {
+        // Remove inner fused arrays parents from the parents map
+        // to keep the parents map in a valid state
+        parents_map.erase(f.id());
+      }
+      global_cache.insert({f.id()});
+    }
+    old_outputs.push_back(arr);
+
+    std::vector<std::vector<int>> shapes;
+    std::vector<Dtype> types;
+    for (auto& o : old_outputs) {
+      shapes.push_back(o.shape());
+      types.push_back(o.dtype());
+    }
+    std::unordered_set<uintptr_t> constant_ids;
+    for (auto& in : inputs) {
+      // Scalar constant
+      if (in.size() == 1 && !in.has_primitive() &&
+          input_ids.find(in.id()) == input_ids.end()) {
+        constant_ids.insert(in.id());
+      }
+    }
+    auto compiled_outputs = array::make_arrays(
+        shapes,
+        types,
+        std::make_shared<Compiled>(
+            old_outputs.back().primitive().stream(),
+            inputs,
+            old_outputs,
+            std::move(fused_tape),
+            std::move(constant_ids)),
+        inputs);
+
+    // One output per primitive
+    new_tape.push_back(compiled_outputs.back());
+
+    // Replace inputs old parents with compiled_outputs
+    for (int i = 0; i < inputs.size(); ++i) {
+      auto& pairs = parents_map[inputs[i].id()];
+      pairs.erase(
+          std::remove_if(
+              pairs.begin(),
+              pairs.end(),
+              [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
+          pairs.end());
+      for (auto& o : compiled_outputs) {
+        pairs.push_back({o, i});
+      }
+    }
+
+    // - Update outputs parents to point to compiled outputs
+    // - Update any overall graph outputs to be compiled outputs
+    for (int o = 0; o < old_outputs.size(); ++o) {
+      merge(compiled_outputs[o], old_outputs[o], parents_map);
+      if (auto it = output_map.find(old_outputs[o].id());
+          it != output_map.end()) {
+        it->second = compiled_outputs[o];
+      }
+    }
+  }
+
+  std::reverse(new_tape.begin(), new_tape.end());
+  tape = std::move(new_tape);
+
+  // Replace output with potentially compiled output
+  for (auto& o : outputs) {
+    o = output_map.at(o.id());
+  }
+}
+
 std::vector<array> compile_replace(
     const std::vector<array>& tape,
     const std::vector<array>& trace_inputs,
@@ -380,10 +697,17 @@ std::vector<array> compile_replace(
 std::function<std::vector<array>(const std::vector<array>&)> compile(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
     size_t fun_id) {
-  if (compiler_disabled()) {
+  if (compile_mode() == CompileMode::disabled) {
     return fun;
   }
   return [fun, fun_id](const std::vector<array>& inputs) {
+    // If the inputs are tracers, trace the original graph
+    if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
+          return in.is_tracer();
+        })) {
+      return fun(inputs);
+    }
+
     // Find a cache entry with the correct inputs
     auto& entry = compiler_cache().find(fun_id, inputs);

@@ -402,10 +726,16 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     compile_dfs(entry.inputs, entry.outputs);

     // Simplify the tape
-    compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
+    if (compile_mode() != CompileMode::no_simplify) {
+      compile_simplify(
+          entry.tape, parents_map, entry.outputs, /* passes */ 3);
+    }

-    // This is a good point to do more optimizations, e.g. kernel fusion to
-    // generate new primitives. The tape needs to be updated accordingly
+    // Kernel fusion to generate Compiled primitives. The tape and
+    // new outputs must be updated accordingly
+    if (compile_mode() != CompileMode::no_fuse) {
+      compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
+    }
   }

   // At this point we must have a tape, now replace the placeholders
@@ -422,7 +752,7 @@ void compile_erase(size_t fun_id) {

 std::function<std::vector<array>(const std::vector<array>&)> compile(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
-  if (detail::compiler_disabled()) {
+  if (detail::compile_mode() == CompileMode::disabled) {
     return fun;
   }
   auto fun_id = detail::getAddress(fun);
@@ -430,11 +760,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 }

 void disable_compile() {
-  detail::compiler_disabled() = true;
+  detail::compile_mode() = CompileMode::disabled;
 }

 void enable_compile() {
-  detail::compiler_disabled() = false;
+  detail::compile_mode() = CompileMode::enabled;
 }

+void set_compile_mode(CompileMode mode) {
+  detail::compile_mode() = mode;
+}
+
 } // namespace mlx::core

mlx/compile.h (new file, 28 lines)
@@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

enum class CompileMode { disabled, no_simplify, no_fuse, enabled };

// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun);

/** Globally disable compilation.
 * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
 * be used to disable compilation.
 */
void disable_compile();

/** Globally enable compilation.
 * This will override the environment variable ``MLX_DISABLE_COMPILE``.
 */
void enable_compile();

/** Set the compiler mode to the given value. */
void set_compile_mode(CompileMode mode);

} // namespace mlx::core
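Taken together, `compile`, `CompileMode`, and `set_compile_mode` give callers per-pass control over compilation. A minimal usage sketch, assuming the public mlx C++ API (`mlx/mlx.h`; the ops used are illustrative):

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto fun = [](const std::vector<array>& ins) {
        // A chain of element-wise ops: a candidate for fusion into a single
        // Compiled primitive when compilation is fully enabled.
        return std::vector<array>{exp(negative(square(ins[0])))};
      };
      auto cfun = compile(fun);

      set_compile_mode(CompileMode::no_fuse);  // simplify the tape, skip fusion
      auto out = cfun({array({1.0f, 2.0f, 3.0f})})[0];
      eval(out);

      set_compile_mode(CompileMode::enabled);  // restore full compilation
    }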

mlx/fast.cpp (new file, 128 lines)
@@ -0,0 +1,128 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast.h"
#include "mlx/transforms.h"

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
  std::vector<array> vjp_outs;
  for (int i = 0, j = 0; i < vjps.size(); ++i) {
    if (i < argnums.size() && i == argnums[j]) {
      vjp_outs.push_back(vjps[i]);
      j++;
    }
  }
  return vjp_outs;
}

std::vector<array> Custom::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
  std::vector<array> jvp_outs;
  for (int i = 0, j = 0; i < jvps.size(); ++i) {
    if (i < argnums.size() && i == argnums[j]) {
      jvp_outs.push_back(jvps[i]);
      j++;
    }
  }
  return jvp_outs;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
  auto out_axes = std::vector<int>(outputs.size(), 0);
  return {outputs, out_axes};
}

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (traditional && x.shape(-1) != dims) {
    throw std::invalid_argument(
        "[rope] Does not support partial traditional application.");
  }

  auto fallback = [dims, traditional, base, scale, offset, s](
                      const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto t = x.dtype();
    auto N = x.shape(1) + offset;
    // Compute sines and cosines
    auto half_dims = dims / 2;
    auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
    auto freqs = negative(arange(0, half_dims, t, s), s);
    freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
    auto theta =
        multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    if (traditional) {
      auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
      auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
    } else {
      auto out_s = x.shape();
      out_s.back() = half_dims;
      auto x1 = slice(x, {0, 0, 0}, out_s, s);
      out_s.back() = dims;
      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);

      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      if (dims < x.shape(-1)) {
        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
      }
      return std::vector<array>{concatenate(outs, 2, s)};
    }
  };
  // TODO change to condition for using custom prim
  auto stream = to_stream(s);
  if (stream.device == Device::gpu && x.shape(-1) == dims) {
    return array(
        x.shape(),
        x.dtype(),
        std::make_unique<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset),
        {x});
  }
  return fallback({x})[0];
}

bool RoPE::is_equivalent(const Primitive& other) const {
  const RoPE& a_other = static_cast<const RoPE&>(other);
  return (
      dims_ == a_other.dims_ && base_ == a_other.base_ &&
      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
      offset_ == a_other.offset_);
}

} // namespace mlx::core::fast
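For reference, the rotation the fallback encodes: with h = dims / 2, frequencies base^(-i/h) (which is what exp(-i · log(base)/h) computes), and scaled positions scale · (t + offset), each pair (x1, x2) is rotated by the resulting angle. In LaTeX, as implied by the code above:

    \theta_{t,i} = \big(\text{scale}\,(t + \text{offset})\big)\,\text{base}^{-i/h},
    \qquad h = \lfloor \text{dims}/2 \rfloor

    \begin{aligned}
    y_1 &= x_1\cos\theta_{t,i} - x_2\sin\theta_{t,i}\\
    y_2 &= x_1\sin\theta_{t,i} + x_2\cos\theta_{t,i}
    \end{aligned}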

mlx/fast.h (new file, 82 lines)
@@ -0,0 +1,82 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/ops.h"
#include "mlx/primitives.h"

namespace mlx::core::fast {

// Custom primitive accepts a fallback function which it uses for
// transformations. Transformations are virtual so that derived classes may
// override the default behavior
class Custom : public Primitive {
 public:
  explicit Custom(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback)
      : Primitive(stream), fallback_(fallback){};

  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  virtual std::vector<array> jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) override;

  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
};

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */);

class RoPE : public Custom {
 public:
  RoPE(
      Stream stream,
      std::function<std::vector<array>(std::vector<array>)> fallback,
      int dims,
      bool traditional,
      float base,
      float scale,
      int offset)
      : Custom(stream, fallback),
        dims_(dims),
        traditional_(traditional),
        base_(base),
        scale_(scale),
        offset_(offset){};

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(RoPE)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::function<std::vector<array>(std::vector<array>)> fallback_;
  int dims_;
  bool traditional_;
  float base_;
  float scale_;
  int offset_;
};

} // namespace mlx::core::fast
@@ -12,27 +12,24 @@

 namespace mlx::core {

-struct NodeNamer {
-  std::unordered_map<std::uintptr_t, std::string> names;
-
-  std::string get_name(const array& x) {
-    auto it = names.find(x.id());
-    if (it == names.end()) {
-      // Get the next name in the sequence
-      // [A, B, ..., Z, AA, AB, ...]
-      std::vector<char> letters;
-      auto var_num = names.size() + 1;
-      while (var_num > 0) {
-        letters.push_back('A' + (var_num - 1) % 26);
-        var_num = (var_num - 1) / 26;
-      }
-      std::string name(letters.rbegin(), letters.rend());
-      names.insert({x.id(), name});
-      return name;
-    }
-    return it->second;
+const std::string& NodeNamer::get_name(const array& x) {
+  auto it = names.find(x.id());
+  if (it == names.end()) {
+    // Get the next name in the sequence
+    // [A, B, ..., Z, AA, AB, ...]
+    std::vector<char> letters;
+    auto var_num = names.size() + 1;
+    while (var_num > 0) {
+      letters.push_back('A' + (var_num - 1) % 26);
+      var_num = (var_num - 1) / 26;
+    }
+    std::string name(letters.rbegin(), letters.rend());
+    names.insert({x.id(), name});
+
+    return get_name(x);
   }
-};
+  return it->second;
+}

 void depth_first_traversal(
     std::function<void(array)> callback,
@@ -6,6 +6,12 @@

 namespace mlx::core {

+struct NodeNamer {
+  std::unordered_map<std::uintptr_t, std::string> names;
+
+  const std::string& get_name(const array& x);
+};
+
 void print_graph(std::ostream& os, const std::vector<array>& outputs);

 template <typename... Arrays>
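The naming sequence NodeNamer produces is base-26 with letters and no zero digit: the n-th new node gets A, B, ..., Z, AA, AB, and so on. A stand-alone sketch of the same arithmetic (the `nth_name` helper is mine, for illustration):

    #include <string>
    #include <vector>

    // nth_name(1) == "A", nth_name(26) == "Z", nth_name(27) == "AA",
    // nth_name(28) == "AB" -- the same loop get_name uses.
    std::string nth_name(size_t var_num) {
      std::vector<char> letters;
      while (var_num > 0) {
        letters.push_back('A' + (var_num - 1) % 26);
        var_num = (var_num - 1) / 26;
      }
      return std::string(letters.rbegin(), letters.rend());
    }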

mlx/io.h (29 changes)
@@ -10,6 +10,14 @@
 #include "mlx/stream.h"

 namespace mlx::core {
+using GGUFMetaData =
+    std::variant<std::monostate, array, std::string, std::vector<std::string>>;
+using GGUFLoad = std::pair<
+    std::unordered_map<std::string, array>,
+    std::unordered_map<std::string, GGUFMetaData>>;
+using SafetensorsLoad = std::pair<
+    std::unordered_map<std::string, array>,
+    std::unordered_map<std::string, std::string>>;

 /** Save array to out stream in .npy format */
 void save(std::shared_ptr<io::Writer> out_stream, array a);
@@ -24,32 +32,29 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
 array load(const std::string& file, StreamOrDevice s = {});

 /** Load array map from .safetensors file format */
-std::unordered_map<std::string, array> load_safetensors(
+SafetensorsLoad load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
     StreamOrDevice s = {});
-std::unordered_map<std::string, array> load_safetensors(
+SafetensorsLoad load_safetensors(
     const std::string& file,
     StreamOrDevice s = {});

 void save_safetensors(
     std::shared_ptr<io::Writer> in_stream,
-    std::unordered_map<std::string, array>);
+    std::unordered_map<std::string, array>,
+    std::unordered_map<std::string, std::string> metadata = {});
 void save_safetensors(
     const std::string& file,
-    std::unordered_map<std::string, array>);
-
-using MetaData =
-    std::variant<std::monostate, array, std::string, std::vector<std::string>>;
+    std::unordered_map<std::string, array>,
+    std::unordered_map<std::string, std::string> metadata = {});

 /** Load array map and metadata from .gguf file format */
-std::pair<
-    std::unordered_map<std::string, array>,
-    std::unordered_map<std::string, MetaData>>
-load_gguf(const std::string& file, StreamOrDevice s = {});
+GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});

 void save_gguf(
     std::string file,
     std::unordered_map<std::string, array> array_map,
-    std::unordered_map<std::string, MetaData> meta_data = {});
+    std::unordered_map<std::string, GGUFMetaData> meta_data = {});

 } // namespace mlx::core
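A sketch of the metadata round-trip these new signatures enable, assuming the public mlx C++ API (file name and keys are illustrative):

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      std::unordered_map<std::string, array> weights = {{"w", ones({4, 4})}};
      std::unordered_map<std::string, std::string> metadata = {{"format", "mlx"}};
      save_safetensors("model.safetensors", weights, metadata);

      // load_safetensors now returns both the arrays and the string metadata.
      auto [loaded, meta] = load_safetensors("model.safetensors");
      // meta["format"] == "mlx"
    }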
@@ -11,14 +11,8 @@ MESSAGE(STATUS "Downloading json")
 FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
 FetchContent_MakeAvailable(json)
 target_include_directories(
-  mlx PUBLIC
+  mlx PRIVATE
   $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
-  $<INSTALL_INTERFACE:include/json>
 )
-install(
-  DIRECTORY ${json_SOURCE_DIR}/
-  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/json
-  COMPONENT json_source
-)

 MESSAGE(STATUS "Downloading gguflib")
@@ -28,14 +22,8 @@ FetchContent_Declare(gguflib
 )
 FetchContent_MakeAvailable(gguflib)
 target_include_directories(
-  mlx PUBLIC
+  mlx PRIVATE
   $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
-  $<INSTALL_INTERFACE:include/gguflib>
 )
-install(
-  DIRECTORY ${gguflib_SOURCE_DIR}/
-  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/gguflib
-  COMPONENT gguflib_source
-)

 add_library(
@@ -82,7 +82,7 @@ void set_mx_value_from_gguf(
     gguf_ctx* ctx,
     uint32_t type,
     gguf_value* val,
-    MetaData& value) {
+    GGUFMetaData& value) {
   switch (type) {
     case GGUF_VALUE_TYPE_UINT8:
       value = array(val->uint8, uint8);
@@ -191,12 +191,12 @@ void set_mx_value_from_gguf(
   }
 }

-std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
-  std::unordered_map<std::string, MetaData> metadata;
+std::unordered_map<std::string, GGUFMetaData> load_metadata(gguf_ctx* ctx) {
+  std::unordered_map<std::string, GGUFMetaData> metadata;
   gguf_key key;
   while (gguf_get_key(ctx, &key)) {
     std::string key_name = std::string(key.name, key.namelen);
-    auto& val = metadata.insert({key_name, MetaData{}}).first->second;
+    auto& val = metadata.insert({key_name, GGUFMetaData{}}).first->second;
     set_mx_value_from_gguf(ctx, key.type, key.val, val);
   }
   return metadata;
@@ -230,10 +230,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   return array_map;
 }

-std::pair<
-    std::unordered_map<std::string, array>,
-    std::unordered_map<std::string, MetaData>>
-load_gguf(const std::string& file, StreamOrDevice s) {
+GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
   gguf_ctx* ctx = gguf_open(file.c_str());
   if (!ctx) {
     throw std::runtime_error("[load_gguf] gguf_init failed");
@@ -280,7 +277,7 @@ void append_kv_array(
 void save_gguf(
     std::string file,
     std::unordered_map<std::string, array> array_map,
-    std::unordered_map<std::string, MetaData> metadata /* = {} */) {
+    std::unordered_map<std::string, GGUFMetaData> metadata /* = {} */) {
   // Add .gguf to file name if it is not there
   if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
     file += ".gguf";
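Since `GGUFMetaData` is a `std::variant`, callers dispatch on the stored alternative. A small sketch, assuming the public mlx C++ API (the key name is illustrative):

    #include <variant>
    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto [arrays, meta] = load_gguf("model.gguf");
      // Metadata values may hold a string, an array, or a list of strings.
      if (auto* name = std::get_if<std::string>(&meta["general.name"])) {
        // string-valued metadata entry, e.g. the model name
      }
    }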
@@ -93,7 +93,7 @@ Dtype dtype_from_safetensor_str(std::string str) {
 }

 /** Load array from reader in safetensor format */
-std::unordered_map<std::string, array> load_safetensors(
+SafetensorsLoad load_safetensors(
     std::shared_ptr<io::Reader> in_stream,
     StreamOrDevice s) {
   ////////////////////////////////////////////////////////
@@ -121,9 +121,12 @@ std::unordered_map<std::string, array> load_safetensors(
   size_t offset = jsonHeaderLength + 8;
   // Load the arrays using metadata
   std::unordered_map<std::string, array> res;
+  std::unordered_map<std::string, std::string> metadata_map;
   for (const auto& item : metadata.items()) {
     if (item.key() == "__metadata__") {
-      // ignore metadata for now
+      for (const auto& meta_item : item.value().items()) {
+        metadata_map.insert({meta_item.key(), meta_item.value()});
+      }
       continue;
     }
     std::string dtype = item.value().at("dtype");
@@ -138,19 +141,18 @@ std::unordered_map<std::string, array> load_safetensors(
         std::vector<array>{});
     res.insert({item.key(), loaded_array});
   }
-  return res;
+  return {res, metadata_map};
 }

-std::unordered_map<std::string, array> load_safetensors(
-    const std::string& file,
-    StreamOrDevice s) {
+SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
   return load_safetensors(std::make_shared<io::FileReader>(file), s);
 }

 /** Save array to out stream in .npy format */
 void save_safetensors(
     std::shared_ptr<io::Writer> out_stream,
-    std::unordered_map<std::string, array> a) {
+    std::unordered_map<std::string, array> a,
+    std::unordered_map<std::string, std::string> metadata /* = {} */) {
   ////////////////////////////////////////////////////////
   // Check file
   if (!out_stream->good() || !out_stream->is_open()) {
@@ -161,9 +163,11 @@ void save_safetensors(
   ////////////////////////////////////////////////////////
   // Check array map
   json parent;
-  parent["__metadata__"] = json::object({
-      {"format", "mlx"},
-  });
+  json _metadata;
+  for (auto& [key, value] : metadata) {
+    _metadata[key] = value;
+  }
+  parent["__metadata__"] = _metadata;
   size_t offset = 0;
   for (auto& [key, arr] : a) {
     arr.eval();
@@ -204,7 +208,8 @@ void save_safetensors(

 void save_safetensors(
     const std::string& file_,
-    std::unordered_map<std::string, array> a) {
+    std::unordered_map<std::string, array> a,
+    std::unordered_map<std::string, std::string> metadata /* = {} */) {
   // Open and check file
   std::string file = file_;

@@ -214,7 +219,7 @@ void save_safetensors(
   file += ".safetensors";

   // Serialize array
-  save_safetensors(std::make_shared<io::FileWriter>(file), a);
+  save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
 }

 } // namespace mlx::core
@@ -4,7 +4,9 @@

 #include "mlx/array.h"
 #include "mlx/backend/metal/metal.h"
+#include "mlx/compile.h"
 #include "mlx/device.h"
+#include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
 #include "mlx/linalg.h"

mlx/ops.cpp (42 changes)
@@ -59,16 +59,6 @@ Dtype at_least_float(const Dtype& d) {

 } // namespace

-Stream to_stream(StreamOrDevice s) {
-  if (std::holds_alternative<std::monostate>(s)) {
-    return default_stream(default_device());
-  } else if (std::holds_alternative<Device>(s)) {
-    return default_stream(std::get<Device>(s));
-  } else {
-    return std::get<Stream>(s);
-  }
-}
-
 array arange(
     double start,
     double stop,
@@ -148,6 +138,9 @@ array linspace(
     msg << "[linspace] number of samples, " << num << ", must be non-negative.";
     throw std::invalid_argument(msg.str());
   }
+  if (num == 1) {
+    return astype(array({start}), dtype, to_stream(s));
+  }
   array sequence = arange(0, num, float32, to_stream(s));
   float step = (stop - start) / (num - 1);
   return astype(
@@ -629,6 +622,13 @@ std::vector<array> split(

 std::vector<array>
 split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
+  auto ax = axis < 0 ? axis + a.ndim() : axis;
+  if (ax < 0 || ax >= a.ndim()) {
+    std::ostringstream msg;
+    msg << "Invalid axis " << axis << " passed to split"
+        << " for array with shape " << a.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
   auto q_and_r = std::ldiv(a.shape(axis), num_splits);
   if (q_and_r.rem) {
     std::ostringstream msg;
@@ -1316,6 +1316,15 @@ array mean(
     const std::vector<int>& axes,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {}*/) {
+  int ndim = a.ndim();
+  for (int axis : axes) {
+    if (axis < -ndim || axis >= ndim) {
+      std::ostringstream msg;
+      msg << "[mean] axis " << axis << " is out of bounds for array with "
+          << ndim << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+  }
   auto nelements = compute_number_of_elements(a, axes);
   auto dtype = at_least_float(a.dtype());
   return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
@@ -1352,7 +1361,7 @@ array var(

   if (ddof != 0) {
     auto nelements = compute_number_of_elements(a, axes);
-    float factor = nelements / (nelements - ddof);
+    auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
     v = multiply(v, array(factor, dtype), s);
   }
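The `var` hunk above fixes two problems at once: the old `factor` divided two integers before converting to float, truncating fractional correction factors, and `nelements - ddof` could go negative; the new code divides in floating point and clamps at zero. The integer-division half in isolation:

    #include <cstdio>

    int main() {
      int n = 5, ddof = 2;
      printf("%f\n", (float)(n / (n - ddof)));           // 1.000000 -- truncated
      printf("%f\n", n / static_cast<float>(n - ddof));  // 1.666667 -- intended
    }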
@@ -1779,10 +1788,6 @@ array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
       inputs);
 }
 array operator&&(const array& a, const array& b) {
-  // check if a and b are bool arrays
-  if (a.dtype() != bool_ || b.dtype() != bool_) {
-    throw std::invalid_argument("[operator&&] only supported for bool arrays.");
-  }
   return logical_and(a, b);
 }

@@ -1797,11 +1802,6 @@ array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
       inputs);
 }
 array operator||(const array& a, const array& b) {
-  // check if a and b are bool arrays
-  if (a.dtype() != bool_ || b.dtype() != bool_) {
-    throw std::invalid_argument(
-        "[operator||] is only supported for bool arrays.");
-  }
   return logical_or(a, b);
 }

@@ -2830,7 +2830,7 @@ array conv2d(
     throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet");
   }
   if (dilation.first != 1 || dilation.second != 1) {
-    throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet");
+    throw std::invalid_argument("[conv2d] Cannot handle dilation != 1 yet");
   }

   // Run checks
@@ -3,18 +3,14 @@
 #pragma once

 #include <optional>
-#include <variant>

 #include "mlx/array.h"
 #include "mlx/device.h"
 #include "mlx/stream.h"
+#include "mlx/utils.h"

 namespace mlx::core {

-using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
-
-Stream to_stream(StreamOrDevice s);
-
 /** Creation operations */

 /**
@@ -2,6 +2,8 @@

 #pragma once

+#include <unordered_set>
+
 #include "mlx/array.h"
 #include "mlx/device.h"
 #include "mlx/io/load.h"
@@ -451,6 +453,53 @@ class Ceil : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };

+class Compiled : public Primitive {
+ public:
+  /*
+   * The inputs, outputs and tape are either tracers or constants.
+   * - The tape should not contain the inputs, but it should contain the
+   *   outputs.
+   * - The tape should also have only one array per primitive for multi-output
+   *   primitives.
+   * - The constant_ids contains ids of arrays in the input list that are safe
+   *   to treat as scalar constants.
+   */
+  explicit Compiled(
+      Stream stream,
+      std::vector<array> inputs,
+      std::vector<array> outputs,
+      std::vector<array> tape,
+      std::unordered_set<uintptr_t> constant_ids);
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  void print(std::ostream& os) override;
+  bool is_equivalent(const Primitive& other) const override;
+
+  std::string metal_lib_name() const {
+    return kernel_lib_;
+  }
+  std::string metal_lib_source() const {
+    return kernel_source_;
+  }
+
+ private:
+  const std::vector<array> inputs_;
+  const std::vector<array> outputs_;
+  const std::vector<array> tape_;
+  const std::unordered_set<uintptr_t> constant_ids_;
+
+  std::string kernel_lib_;
+  std::string kernel_source_;
+
+  void eval(const std::vector<array>& inputs, std::vector<array>& out);
+};
+
 class Concatenate : public UnaryPrimitive {
  public:
   explicit Concatenate(Stream stream, int axis)
@@ -667,9 +716,16 @@ class Equal : public UnaryPrimitive {

   DEFINE_VMAP()
   DEFINE_GRADS()
-  DEFINE_PRINT(Equal)
   DEFINE_DEFAULT_IS_EQUIVALENT()

+  void print(std::ostream& os) override {
+    if (equal_nan_) {
+      os << "NanEqual";
+    } else {
+      os << "Equal";
+    }
+  }
+
  private:
   void eval(const std::vector<array>& inputs, array& out);
   bool equal_nan_;
@@ -903,9 +959,22 @@ class Log : public UnaryPrimitive {

   DEFINE_VMAP()
   DEFINE_GRADS()
-  DEFINE_PRINT(Log)
   DEFINE_DEFAULT_IS_EQUIVALENT()

+  void print(std::ostream& os) override {
+    switch (base_) {
+      case e:
+        os << "Log";
+        break;
+      case two:
+        os << "Log2";
+        break;
+      case ten:
+        os << "Log10";
+        break;
+    }
+  }
+
  private:
   Base base_;
   void eval(const std::vector<array>& inputs, array& out);
@@ -1552,9 +1621,16 @@ class Sqrt : public UnaryPrimitive {

   DEFINE_VMAP()
   DEFINE_GRADS()
-  DEFINE_PRINT(Sqrt)
   bool is_equivalent(const Primitive& other) const override;

+  void print(std::ostream& os) override {
+    if (recip_) {
+      os << "Rsqrt";
+    } else {
+      os << "Sqrt";
+    }
+  }
+
  private:
   void eval(const std::vector<array>& inputs, array& out);
   bool recip_;
@@ -153,14 +153,23 @@ array uniform(
 array normal(
     const std::vector<int>& shape,
     Dtype dtype,
+    const float loc /* = 0.0 */,
+    const float scale /* = 1.0 */,
     const std::optional<array>& key /*= nullopt */,
     StreamOrDevice s /* = {} */) {
   auto stream = to_stream(s);
   auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
   auto high = array(1.0f, dtype);
   auto samples = uniform(low, high, shape, dtype, key, stream);
-  return multiply(
-      array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
+  samples =
+      multiply(array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream);
+  if (scale != 1.0) {
+    samples = multiply(array(scale, dtype), samples, stream);
+  }
+  if (loc != 0.0) {
+    samples = add(array(loc, dtype), samples, stream);
+  }
+  return samples;
 }

 array randint(

mlx/random.h (19 changes)
@@ -95,13 +95,30 @@ inline array uniform(
 array normal(
     const std::vector<int>& shape,
     Dtype dtype,
+    const float loc,
+    const float scale,
     const std::optional<array>& key = std::nullopt,
     StreamOrDevice s = {});
 inline array normal(
     const std::vector<int>& shape,
+    const float loc,
+    const float scale,
+    const std::optional<array>& key = std::nullopt,
+    StreamOrDevice s = {}) {
+  return normal(shape, float32, loc, scale, key, s);
+}
+inline array normal(
+    const std::vector<int>& shape,
+    const Dtype dtype,
+    const std::optional<array>& key = std::nullopt,
+    StreamOrDevice s = {}) {
+  return normal(shape, dtype, 0.0, 1.0, key, s);
+}
+inline array normal(
+    const std::vector<int>& shape,
     const std::optional<array>& key = std::nullopt,
     StreamOrDevice s = {}) {
-  return normal(shape, float32, key, s);
+  return normal(shape, float32, 0.0, 1.0, key, s);
 }

 /** Generate integer samples uniformly at random */
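Usage of the new overloads, assuming the public mlx C++ API:

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto a = random::normal({2, 3});                              // standard normal
      auto b = random::normal({2, 3}, /* loc */ 1.0f, /* scale */ 0.5f);
      eval(a, b);
    }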
@@ -6,21 +6,6 @@

 namespace mlx::core {

-// Compile takes a function and returns a new function
-std::function<std::vector<array>(const std::vector<array>&)> compile(
-    const std::function<std::vector<array>(const std::vector<array>&)>& fun);
-
-/** Globally disable compilation.
- * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
- * be used to disable compilation.
- */
-void disable_compile();
-
-/** Globally enable compilation.
- * This will override the environment variable ``MLX_DISABLE_COMPILE``.
- */
-void enable_compile();
-
 void eval(const std::vector<array>& outputs);

 template <typename... Arrays>
@@ -35,6 +35,16 @@ inline bool operator>(const complex64_t& a, const complex64_t& b) {
   return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
 }

+inline complex64_t operator%(complex64_t a, complex64_t b) {
+  auto real = a.real() - (b.real() * static_cast<int64_t>(a.real() / b.real()));
+  auto imag = a.imag() - (b.imag() * static_cast<int64_t>(a.imag() / b.imag()));
+  if (real != 0 && (real < 0 != b.real() < 0))
+    real += b.real();
+  if (imag != 0 && (imag < 0 != b.imag() < 0))
+    imag += b.imag();
+  return {real, imag};
+}
+
 inline bool operator<=(const complex64_t& a, const complex64_t& b) {
   return operator>=(b, a);
 }
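Note: the new operator% applies floored modulo (remainder takes the sign of the divisor) independently to the real and imaginary parts, which is the same convention as Python's %. A small sketch of the per-component logic:

def floored_mod(a: float, b: float) -> float:
    r = a - b * int(a / b)            # truncated remainder, as in the C++ code
    if r != 0 and (r < 0) != (b < 0):
        r += b                        # shift into the sign of the divisor
    return r

assert floored_mod(5.5, -2.0) == 5.5 % -2.0 == -0.5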
@@ -7,6 +7,16 @@

 namespace mlx::core {

+Stream to_stream(StreamOrDevice s) {
+  if (std::holds_alternative<std::monostate>(s)) {
+    return default_stream(default_device());
+  } else if (std::holds_alternative<Device>(s)) {
+    return default_stream(std::get<Device>(s));
+  } else {
+    return std::get<Stream>(s);
+  }
+}
+
 void PrintFormatter::print(std::ostream& os, bool val) {
   if (capitalize_bool) {
     os << (val ? "True" : "False");
mlx/utils.h
@@ -2,6 +2,8 @@

 #pragma once

+#include <variant>
+
 #include "array.h"
 #include "device.h"
 #include "dtype.h"
@@ -9,6 +11,30 @@

 namespace mlx::core {

+using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
+Stream to_stream(StreamOrDevice s);
+
+struct StreamContext {
+ public:
+  StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
+    if (std::holds_alternative<std::monostate>(s)) {
+      throw std::runtime_error(
+          "[StreamContext] Invalid argument, please specify a stream or device.");
+    }
+    auto _s = to_stream(s);
+    set_default_device(_s.device);
+    set_default_stream(_s);
+  }
+
+  ~StreamContext() {
+    set_default_device(_stream.device);
+    set_default_stream(_stream);
+  }
+
+ private:
+  Stream _stream;
+};
+
 struct PrintFormatter {
   inline void print(std::ostream& os, bool val);
   inline void print(std::ostream& os, int16_t val);
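Note: StreamContext is an RAII guard that swaps the default device and stream on construction and restores them on destruction. That is exactly the shape of a Python context manager over streams; a minimal sketch, assuming the mx.stream context manager this struct would back:

import mlx.core as mx

# Ops inside the block run on the CPU's default stream; the previous
# default device/stream are restored when the block exits.
with mx.stream(mx.cpu):
    y = mx.add(mx.array(1.0), mx.array(2.0))
print(y)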
@@ -60,7 +60,7 @@ def normal(
     """

     def initializer(a: mx.array) -> mx.array:
-        return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean
+        return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)

     return initializer

@@ -184,7 +184,7 @@ def glorot_normal(
     def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
         fan_in, fan_out = _calculate_fan_in_fan_out(a)
         std = gain * math.sqrt(2.0 / (fan_in + fan_out))
-        return mx.random.normal(shape=a.shape, dtype=dtype) * std
+        return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)

     return initializer

@@ -285,7 +285,7 @@ def he_normal(
             raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")

         std = gain / math.sqrt(fan)
-        return mx.random.normal(shape=a.shape, dtype=dtype) * std
+        return mx.random.normal(shape=a.shape, scale=std, dtype=dtype)

     return initializer
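Note: all three initializers now delegate shifting and scaling to mx.random.normal itself instead of doing the arithmetic afterwards; a short usage sketch of the normal initializer edited above, assuming its mean/std keyword parameters:

import mlx.core as mx
import mlx.nn as nn

init_fn = nn.init.normal(mean=0.0, std=0.02)
w = init_fn(mx.zeros((8, 8)))  # samples drawn with loc=0.0, scale=0.02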
@@ -58,6 +58,7 @@ from mlx.nn.layers.normalization import (
     LayerNorm,
     RMSNorm,
 )
+from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
@@ -66,6 +66,19 @@ class Module(dict):
         """Boolean indicating if the model is in training mode."""
         return self._training

+    @property
+    def state(self):
+        """The module's state dictionary
+
+        The module's state dictionary contains any attribute set on the
+        module including parameters in :meth:`Module.parameters`
+
+        Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
+        a reference to the module's state. Updates to it will be reflected in
+        the original module.
+        """
+        return self
+
     def _extra_repr(self):
         return ""
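Note: because state returns the module itself rather than a snapshot, writes through it mutate the module in place; a minimal sketch (Module subclasses dict, so item assignment reaches the parameters):

import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(4, 4)
layer.state["weight"] = mx.zeros((4, 4))  # write through the live reference
print(mx.all(layer.weight == 0))          # the module itself sees the update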
@@ -312,7 +325,7 @@ class Module(dict):
                 elif isinstance(current_value, (dict, list)):
                     apply(current_value, new_value)
             elif isinstance(parameters, list):
-                for i in range(len(dst)):
+                for i in range(len(parameters)):
                     current_value = dst[i]
                     new_value = parameters[i]
                     if isinstance(current_value, mx.array):
@@ -21,7 +21,7 @@ class Embedding(Module):
     def __init__(self, num_embeddings: int, dims: int):
         super().__init__()
         scale = math.sqrt(1 / dims)
-        self.weight = mx.random.normal((num_embeddings, dims)) * scale
+        self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)

     def _extra_repr(self):
         return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
python/mlx/nn/layers/pooling.py (new file)
@@ -0,0 +1,308 @@
# Copyright © 2023-2024 Apple Inc.

import operator
from itertools import accumulate
from typing import Optional, Tuple, Union

import mlx.core as mx
from mlx.nn.layers.base import Module


def _value_or_list(x, n, msg):
    if isinstance(x, (list, tuple)):
        if len(x) != n:
            raise ValueError(msg)
        return list(x)

    if not isinstance(x, int):
        raise ValueError(msg)

    return [x] * n

def _sliding_windows(x, window_shape, window_strides):
    if x.ndim < 3:
        raise ValueError(
            f"To extract sliding windows at least 1 spatial dimension "
            f"(3 total) is needed but the input only has {x.ndim} dimensions."
        )

    spatial_dims = x.shape[1:-1]
    if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
        raise ValueError(
            f"To extract sliding windows the window shapes and strides must have "
            f"the same number of spatial dimensions as the signal but the signal "
            f"has {len(spatial_dims)} dims and the window shape has {len(window_shape)} "
            f"and strides have {len(window_strides)}."
        )

    shape = x.shape
    strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]

    # Compute the output shape
    final_shape = [shape[0]]
    final_shape += [
        (size - window) // stride + 1
        for size, window, stride in zip(spatial_dims, window_shape, window_strides)
    ]
    final_shape += window_shape
    final_shape += [shape[-1]]

    # Compute the output strides
    final_strides = strides[:1]
    final_strides += [
        og_stride * stride for og_stride, stride in zip(strides[1:-1], window_strides)
    ]
    final_strides += strides[1:-1]
    final_strides += strides[-1:]  # should always be [1]

    return mx.as_strided(x, final_shape, final_strides)

class _Pool(Module):
    def __init__(self, pooling_function, kernel_size, stride, padding, padding_value):
        super().__init__()

        self._pooling_function = pooling_function
        self._kernel_size = kernel_size
        self._stride = stride
        self._padding = padding
        self._padding_value = padding_value
        self._axes = tuple(range(-len(self._kernel_size) - 1, -1, 1))

    def _extra_repr(self):
        ks = tuple(self._kernel_size)
        st = tuple(self._stride)
        pd = tuple(p[0] for p in self._padding)

        return f"kernel_size={ks}, stride={st}, padding={pd}"

    def __call__(self, x):
        if any(p[0] > 0 for p in self._padding):
            x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value)
        x = _sliding_windows(x, self._kernel_size, self._stride)
        return self._pooling_function(x, self._axes)


class _Pool1d(_Pool):
    def __init__(
        self,
        pooling_function,
        padding_value,
        kernel_size: Union[int, Tuple[int]],
        stride: Optional[Union[int, Tuple[int]]] = None,
        padding: Union[int, Tuple[int]] = 0,
    ):
        class_name = type(self).__name__
        msg = "[{}] '{}' must be an integer or a tuple containing 1 integer"
        kernel_size = _value_or_list(
            kernel_size, 1, msg.format(class_name, "kernel_size")
        )
        if stride is not None:
            stride = _value_or_list(stride, 1, msg.format(class_name, "stride"))
        else:
            stride = kernel_size
        padding = _value_or_list(padding, 1, msg.format(class_name, "padding"))
        padding = [(p, p) for p in padding]

        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)


class _Pool2d(_Pool):
    def __init__(
        self,
        pooling_function,
        padding_value,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
    ):
        class_name = type(self).__name__
        msg = "[{}] '{}' must be an integer or a tuple containing 2 integers"
        kernel_size = _value_or_list(
            kernel_size, 2, msg.format(class_name, "kernel_size")
        )
        if stride is not None:
            stride = _value_or_list(stride, 2, msg.format(class_name, "stride"))
        else:
            stride = kernel_size
        padding = _value_or_list(padding, 2, msg.format(class_name, "padding"))
        padding = [(p, p) for p in padding]

        super().__init__(pooling_function, kernel_size, stride, padding, padding_value)

class MaxPool1d(_Pool1d):
    r"""Applies 1-dimensional max pooling.

    Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
    :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
    by:

    .. math::
        \text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1}
            \text{input}(N_i, \text{stride} \times t + m, C_j),

    where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
    \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.

    Args:
        kernel_size (int or tuple(int)): The size of the pooling window kernel.
        stride (int or tuple(int), optional): The stride of the pooling window.
            Default: ``kernel_size``.
        padding (int or tuple(int), optional): How much negative infinity
            padding to apply to the input. The padding amount is applied to
            both sides of the spatial axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(4, 16, 5))
        >>> pool = nn.MaxPool1d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
    ):
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)


class AvgPool1d(_Pool1d):
    r"""Applies 1-dimensional average pooling.

    Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
    :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
    by:

    .. math::
        \text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1}
            \text{input}(N_i, \text{stride} \times t + m, C_j),

    where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
    \text{kernel_size}}{\text{stride}}\right\rfloor + 1`.

    Args:
        kernel_size (int or tuple(int)): The size of the pooling window kernel.
        stride (int or tuple(int), optional): The stride of the pooling window.
            Default: ``kernel_size``.
        padding (int or tuple(int), optional): How much zero padding to apply to
            the input. The padding amount is applied to both sides of the spatial
            axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(4, 16, 5))
        >>> pool = nn.AvgPool1d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
    ):
        super().__init__(mx.mean, 0, kernel_size, stride, padding)

class MaxPool2d(_Pool2d):
    r"""Applies 2-dimensional max pooling.

    Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
    :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
    W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
                                         & \text{input}(N_i, \text{stride[0]} \times h + m,
                                           \text{stride[1]} \times w + n, C_j),
        \end{aligned}

    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:

        - a single ``int`` -- in which case the same value is used for both the
          height and width axis;
        - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
          used for the height axis, the second ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int)): The size of the pooling window.
        stride (int or tuple(int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int), optional): How much negative infinity
            padding to apply to the input. The padding is applied on both sides
            of the height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.MaxPool2d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
    ):
        super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)

class AvgPool2d(_Pool2d):
    r"""Applies 2-dimensional average pooling.

    Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
    :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
    W_{out}, C)`, given by:

    .. math::
        \begin{aligned}
            \text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
                                         & \text{input}(N_i, \text{stride[0]} \times h + m,
                                           \text{stride[1]} \times w + n, C_j),
        \end{aligned}

    where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
    :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.

    The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:

        - a single ``int`` -- in which case the same value is used for both the
          height and width axis;
        - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
          used for the height axis, the second ``int`` for the width axis.

    Args:
        kernel_size (int or tuple(int, int)): The size of the pooling window.
        stride (int or tuple(int, int), optional): The stride of the pooling
            window. Default: ``kernel_size``.
        padding (int or tuple(int, int), optional): How much zero
            padding to apply to the input. The padding is applied on both sides
            of the height and width axis. Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn.layers as nn
        >>> x = mx.random.normal(shape=(8, 32, 32, 4))
        >>> pool = nn.AvgPool2d(kernel_size=2, stride=2)
        >>> pool(x)
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
    ):
        super().__init__(mx.mean, 0, kernel_size, stride, padding)
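Note: a quick shape check of the new layers, using only the classes defined in this file; the strided-window trick means no data is copied before the reduction, and the output spatial sizes follow the floor formula in the docstrings:

import mlx.core as mx
import mlx.nn as nn

x1 = mx.random.normal(shape=(4, 16, 5))                  # (N, L, C)
print(nn.MaxPool1d(kernel_size=2, stride=2)(x1).shape)   # (4, 8, 5)

x2 = mx.random.normal(shape=(8, 32, 32, 4))              # (N, H, W, C)
print(nn.AvgPool2d(kernel_size=2, stride=2)(x2).shape)   # (8, 16, 16, 4)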
python/mlx/nn/layers/positional_encoding.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
 from typing import Optional
@@ -20,20 +20,13 @@ class RoPE(Module):
     Args:
         dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
-        traditional (bool, optional): If set to True choose the traditional
+        traditional (bool, optional): If set to ``True`` choose the traditional
            implementation which is slightly less efficient. Default: ``False``.
         base (float, optional): The base used to compute angular frequency for
            each dimension in the positional encodings. Default: ``10000``.
         scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
-
-    Attributes:
-        _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values.
-        _cos_sin_theta_value (tuple): Cached cosine and sine values.
     """

-    _cos_sin_theta_key = None
-    _cos_sin_theta_value = None
-
     def __init__(
         self,
         dims: int,
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
@@ -50,69 +43,18 @@ class RoPE(Module):
|
||||
def _extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}"
|
||||
|
||||
def _compute_rope(self, costheta, sintheta, x):
|
||||
x1 = x[..., : self.dims // 2]
|
||||
x2 = x[..., self.dims // 2 : self.dims]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
|
||||
if self.dims < x.shape[-1]:
|
||||
rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
|
||||
else:
|
||||
rx = mx.concatenate([rx1, rx2], axis=-1)
|
||||
|
||||
return rx
|
||||
|
||||
def _compute_traditional_rope(self, costheta, sintheta, x):
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
|
||||
if self.dims < x.shape[-1]:
|
||||
raise NotImplementedError(
|
||||
"RoPE doesn't implement partial traditional application"
|
||||
)
|
||||
|
||||
rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
|
||||
|
||||
return rx
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = mx.reshape(x, (-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
|
||||
x = mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=self.base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return mx.reshape(rx, shape)
|
||||
|
||||
@classmethod
|
||||
def create_cos_sin_theta(
|
||||
cls,
|
||||
N: int,
|
||||
D: int,
|
||||
offset: int = 0,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
dtype=mx.float32,
|
||||
):
|
||||
if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key:
|
||||
half_D = D // 2
|
||||
positions = mx.arange(offset, N, dtype=dtype) * scale
|
||||
freqs = mx.exp(
|
||||
-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D)
|
||||
)
|
||||
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
||||
cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype)
|
||||
cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta))
|
||||
return cls._cos_sin_theta_value
|
||||
return mx.reshape(x, shape)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(Module):
|
||||
|
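Note: the hand-rolled rotation helpers are replaced by a single fused kernel call; a minimal sketch of calling it directly, using the mx.fast.rope signature shown in the diff (input reshaped to three dimensions, as in __call__):

import mlx.core as mx

x = mx.random.normal(shape=(1, 8, 64))  # (batch, sequence, feature)
y = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)
print(y.shape)  # (1, 8, 64)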
@@ -117,6 +117,7 @@ def cross_entropy(
 def binary_cross_entropy(
     inputs: mx.array,
     targets: mx.array,
+    weights: mx.array = None,
     with_logits: bool = True,
     reduction: Reduction = "mean",
 ) -> mx.array:
@@ -128,6 +129,7 @@ def binary_cross_entropy(
             ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
         targets (array): The binary target values in {0, 1}.
         with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
+        weights (array, optional): Optional weights for each target. Default: ``None``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.

@@ -159,6 +161,15 @@ def binary_cross_entropy(
     else:
         loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))

+    # Apply weights if provided
+    if weights is not None:
+        if weights.shape != loss.shape:
+            raise ValueError(
+                f"Weights with shape {weights.shape} is not the same as "
+                f"output loss with shape {loss.shape}."
+            )
+        loss *= weights
+
     return _reduce(loss, reduction)
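Note: a short usage sketch of the new weights argument; per the check above, the weight array must match the per-element loss shape exactly:

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([0.2, -0.8, 1.1])
targets = mx.array([1.0, 0.0, 1.0])
weights = mx.array([1.0, 2.0, 0.5])  # must match the per-element loss shape
loss = nn.losses.binary_cross_entropy(logits, targets, weights=weights)
print(loss)  # weighted mean over the three elements (reduction="mean")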
@@ -533,3 +544,58 @@ def cosine_similarity_loss(
     loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)

     return _reduce(loss, reduction)
+
+
+def margin_ranking_loss(
+    inputs1: mx.array,
+    inputs2: mx.array,
+    targets: mx.array,
+    margin: float = 0.0,
+    reduction: Reduction = "none",
+) -> mx.array:
+    r"""
+    Calculate the margin ranking loss given inputs :math:`x_1`, :math:`x_2`, and a label
+    :math:`y` (containing 1 or -1).
+
+    The loss is given by:
+
+    .. math::
+        \text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})
+
+    Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2`
+    represents ``inputs2``.
+
+    Args:
+        inputs1 (array): Scores for the first input.
+        inputs2 (array): Scores for the second input.
+        targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher
+            than samples in ``inputs2``. Values should be 1 or -1.
+        margin (float, optional): The margin by which the scores should be separated.
+            Default: ``0.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed margin ranking loss.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> targets = mx.array([1, 1, -1])
+        >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
+        >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
+        >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)
+        >>> loss
+        array(0.773433, dtype=float32)
+    """
+    if not (inputs1.shape == inputs2.shape == targets.shape):
+        raise ValueError(
+            f"The shapes of the arguments do not match. The provided shapes are "
+            f"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and "
+            f"targets.shape={targets.shape}."
+        )
+
+    differences = inputs1 - inputs2
+    loss = mx.maximum(0, -targets * differences + margin)
+
+    return _reduce(loss, reduction)
Some files were not shown because too many files have changed in this diff.