mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00
Compare commits
139 Commits
async_all_
...
v0.17.3
Author | SHA1 | Date | |
---|---|---|---|
![]() |
d0c58841d1 | ||
![]() |
881f09b2e2 | ||
![]() |
8b30acd7eb | ||
![]() |
02efb310ca | ||
![]() |
e7e59c6f05 | ||
![]() |
3ae6aabe9f | ||
![]() |
dc627dcb5e | ||
![]() |
efeb9c0f02 | ||
![]() |
ba3e913c7a | ||
![]() |
7cca1727af | ||
![]() |
11371fe251 | ||
![]() |
41c603d48a | ||
![]() |
969337345f | ||
![]() |
9592766939 | ||
![]() |
58dca7d846 | ||
![]() |
0d302cd25b | ||
![]() |
da691257ec | ||
![]() |
1600092e92 | ||
![]() |
dba2bd1105 | ||
![]() |
28be4de7c2 | ||
![]() |
a6c3b38fba | ||
![]() |
fcb65a3897 | ||
![]() |
4e22a1dffe | ||
![]() |
291cf40aca | ||
![]() |
bd47e1f066 | ||
![]() |
e6b223df5f | ||
![]() |
e64349bbdd | ||
![]() |
cdb59faea6 | ||
![]() |
1d94ac3f90 | ||
![]() |
5f7d19d1f5 | ||
![]() |
2fdf9eb535 | ||
![]() |
860d3a50d7 | ||
![]() |
d1183821a7 | ||
![]() |
8081df79be | ||
![]() |
64bec4fad7 | ||
![]() |
b96e105244 | ||
![]() |
3b4d5484c7 | ||
![]() |
684e11c664 | ||
![]() |
b57a52813b | ||
![]() |
da8deb2b62 | ||
![]() |
98b6ce3460 | ||
![]() |
f9e00efe31 | ||
![]() |
0fd2a1f4b0 | ||
![]() |
df3233454d | ||
![]() |
82db84b899 | ||
![]() |
8ae751d3da | ||
![]() |
d40e76809f | ||
![]() |
bb1b76d9dc | ||
![]() |
9d26441224 | ||
![]() |
f12f24a77c | ||
![]() |
ae5b5cabfd | ||
![]() |
d0630ffe8c | ||
![]() |
99bb7d3a58 | ||
![]() |
63ae767232 | ||
![]() |
eaaea02010 | ||
![]() |
a098bc92e0 | ||
![]() |
1086dc4db0 | ||
![]() |
19fb69e2ed | ||
![]() |
9231617eb3 | ||
![]() |
32668a7317 | ||
![]() |
780c197f95 | ||
![]() |
eb8819e91e | ||
![]() |
30bbea2f08 | ||
![]() |
635ccd9e25 | ||
![]() |
8c9f0278b9 | ||
![]() |
58d0e199e1 | ||
![]() |
10b5835501 | ||
![]() |
6c8dd307eb | ||
![]() |
43ffdab172 | ||
![]() |
40b6d67333 | ||
![]() |
c52d1600f0 | ||
![]() |
aa1d6cadad | ||
![]() |
6e06e3a904 | ||
![]() |
8cfb9fc0b8 | ||
![]() |
7b456fd2c0 | ||
![]() |
e9e53856d2 | ||
![]() |
5029894662 | ||
![]() |
baf9fa5f42 | ||
![]() |
7f914365fd | ||
![]() |
ebd7135b50 | ||
![]() |
50eff6a10a | ||
![]() |
c34a5ae7f7 | ||
![]() |
e2aa6ec8ae | ||
![]() |
6768c6a54a | ||
![]() |
6307d166eb | ||
![]() |
1fba87b0df | ||
![]() |
df124e018a | ||
![]() |
2f83d6e4b7 | ||
![]() |
987785d8d7 | ||
![]() |
8c01a7893b | ||
![]() |
218047c75a | ||
![]() |
d0da74209b | ||
![]() |
5c1fa64fb0 | ||
![]() |
a3c287354f | ||
![]() |
03cf033f82 | ||
![]() |
bdb36c9a63 | ||
![]() |
20bb301195 | ||
![]() |
d6383a1c6a | ||
![]() |
b05bcfd27f | ||
![]() |
2615660e62 | ||
![]() |
5b0af4cdb1 | ||
![]() |
8c2e15e6c8 | ||
![]() |
56c8a33439 | ||
![]() |
4eef1e8a3e | ||
![]() |
95d11bda06 | ||
![]() |
af9079cc1f | ||
![]() |
2d6cd47713 | ||
![]() |
fe3167d7ea | ||
![]() |
31e134be35 | ||
![]() |
e84ba8056d | ||
![]() |
f20e97b092 | ||
![]() |
934683088e | ||
![]() |
de2b9e7d0a | ||
![]() |
dd7d8e5e29 | ||
![]() |
df964132fb | ||
![]() |
709ccc6800 | ||
![]() |
cf236fc390 | ||
![]() |
27d70c7d9d | ||
![]() |
0e585b4409 | ||
![]() |
0163a8e57a | ||
![]() |
578842954c | ||
![]() |
496315fe1d | ||
![]() |
0fe6895893 | ||
![]() |
0b7d71fd2f | ||
![]() |
83b11bc58d | ||
![]() |
375a8bbdcc | ||
![]() |
ea9090bbc4 | ||
![]() |
81def6ac76 | ||
![]() |
3de8ce3f3c | ||
![]() |
4d485fca24 | ||
![]() |
1865299a30 | ||
![]() |
3576b547c5 | ||
![]() |
079882495d | ||
![]() |
ab977109db | ||
![]() |
fd1c08137b | ||
![]() |
76b6cece46 | ||
![]() |
9f0df51f8d | ||
![]() |
e7a2a3dcd1 | ||
![]() |
a87ef5bfc1 |
@@ -31,19 +31,24 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.1.0
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
@@ -52,7 +57,9 @@ jobs:
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests
|
||||
@@ -76,7 +83,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.1.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@@ -85,11 +92,12 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
@@ -111,7 +119,7 @@ jobs:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: |
|
||||
@@ -121,8 +129,23 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
|
||||
make -j
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
@@ -144,11 +167,12 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@<< parameters.python_version >>
|
||||
brew install openmpi
|
||||
python<< parameters.python_version >> -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.1.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
@@ -158,19 +182,20 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
@@ -212,18 +237,19 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install nanobind==2.1.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
auditwheel show dist/*
|
||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||
@@ -244,7 +270,7 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
- linux_build_and_test
|
||||
|
||||
build_pypi_release:
|
||||
@@ -279,7 +305,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -303,7 +329,7 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
linux_test_release:
|
||||
when:
|
||||
|
2
.github/workflows/pull_request.yml
vendored
2
.github/workflows/pull_request.yml
vendored
@@ -17,4 +17,4 @@ jobs:
|
||||
pip install pre-commit black isort clang-format
|
||||
- name: Run lint
|
||||
run: |
|
||||
pre-commit run --all-files
|
||||
pre-commit run --all-files
|
||||
|
@@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.4
|
||||
rev: v18.1.8
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.4.2
|
||||
rev: 24.8.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@@ -10,13 +10,15 @@ MLX was developed with contributions from the following individuals:
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.14.0)
|
||||
set(MLX_VERSION 0.17.3)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
@@ -83,24 +83,21 @@ elseif (MLX_BUILD_METAL)
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
set(MLX_METAL_VERSION METAL_3_1)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
set(MLX_METAL_VERSION METAL_3_0)
|
||||
else()
|
||||
if (${MACOS_VERSION} LESS 14.0)
|
||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||
endif()
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
|
||||
# Get the metal version
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION
|
||||
COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
FetchContent_Declare(
|
||||
metal_cpp
|
||||
URL ${METAL_CPP_URL}
|
||||
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
@@ -115,7 +112,7 @@ elseif (MLX_BUILD_METAL)
|
||||
${FOUNDATION_LIB}
|
||||
${QUARTZ_LIB})
|
||||
|
||||
add_compile_definitions(${MLX_METAL_VERSION})
|
||||
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_CPU)
|
||||
@@ -169,7 +166,26 @@ endif()
|
||||
|
||||
find_package(MPI)
|
||||
if (MPI_FOUND)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "mpirun --version"
|
||||
OUTPUT_VARIABLE MPI_VERSION
|
||||
ERROR_QUIET
|
||||
)
|
||||
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||
elseif (MPI_VERSION STREQUAL "")
|
||||
set(MPI_FOUND FALSE)
|
||||
message(
|
||||
WARNING
|
||||
"MPI found but mpirun is not available. Building without MPI."
|
||||
)
|
||||
else()
|
||||
set(MPI_FOUND FALSE)
|
||||
message(
|
||||
WARNING
|
||||
"MPI which is not OpenMPI found. Building without MPI."
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
|
@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
|
||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
return torch.nn.functional.mish(y)
|
||||
y = torch.nn.functional.mish(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@@ -283,6 +283,14 @@ def topk(axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step_function(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.where(y < 0, 0, 1)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def selu(x):
|
||||
y = x
|
||||
@@ -446,5 +454,11 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "topk":
|
||||
print(bench(topk, axis, x))
|
||||
|
||||
elif args.benchmark == "step":
|
||||
print(bench(step_function, x))
|
||||
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown benchmark")
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
|
||||
result = run(*args, capture_output=True, **kwargs)
|
||||
return float(result.stdout)
|
||||
except ValueError:
|
||||
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
|
||||
raise ValueError(
|
||||
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
|
||||
)
|
||||
|
||||
|
||||
def compare(args):
|
||||
|
@@ -9,7 +9,6 @@ from time_utils import time_fn
|
||||
|
||||
|
||||
def bench_gelu():
|
||||
|
||||
def gelu(x):
|
||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||
|
||||
@@ -51,7 +50,6 @@ def bench_gelu():
|
||||
|
||||
|
||||
def bench_layernorm():
|
||||
|
||||
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
mx.eval(weight, bias)
|
||||
|
@@ -54,7 +54,6 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
|
135
benchmarks/python/conv_transpose_bench.py
Normal file
135
benchmarks/python/conv_transpose_bench.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_transpose_2D
|
||||
|
||||
|
||||
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_transpose_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv_transpose2d(
|
||||
a, b, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_transpose_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv_transpose2d(
|
||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.conv_transpose2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
66
benchmarks/python/distributed_bench.py
Normal file
66
benchmarks/python/distributed_bench.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
Run with:
|
||||
mpirun -n 2 python /path/to/distributed_bench.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def time_fn(fn, *args, **kwargs):
|
||||
msg = kwargs.pop("msg", None)
|
||||
world = mx.distributed.init()
|
||||
if world.rank() == 0:
|
||||
if msg:
|
||||
print(f"Timing {msg} ...", end=" ")
|
||||
else:
|
||||
print(f"Timing {fn.__name__} ...", end=" ")
|
||||
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
mx.eval(fn(*args, **kwargs))
|
||||
|
||||
num_iters = 100
|
||||
tic = time.perf_counter()
|
||||
for _ in range(num_iters):
|
||||
x = mx.eval(fn(*args, **kwargs))
|
||||
toc = time.perf_counter()
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
if world.rank() == 0:
|
||||
print(f"{msec:.5f} msec")
|
||||
|
||||
|
||||
def time_all_sum():
|
||||
shape = (4096,)
|
||||
x = mx.random.uniform(shape=shape)
|
||||
mx.eval(x)
|
||||
|
||||
def sine(x):
|
||||
for _ in range(20):
|
||||
x = mx.sin(x)
|
||||
return x
|
||||
|
||||
time_fn(sine, x)
|
||||
|
||||
def all_sum_plain(x):
|
||||
for _ in range(20):
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
time_fn(all_sum_plain, x)
|
||||
|
||||
def all_sum_with_sine(x):
|
||||
for _ in range(20):
|
||||
x = mx.sin(x)
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
time_fn(all_sum_with_sine, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_all_sum()
|
84
benchmarks/python/einsum_bench.py
Normal file
84
benchmarks/python/einsum_bench.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def timeit(fn, its=100, args=[]):
|
||||
for _ in range(5):
|
||||
fn(*args)
|
||||
tic = time.perf_counter()
|
||||
for _ in range(its):
|
||||
fn(*args)
|
||||
toc = time.perf_counter()
|
||||
return 1e3 * (toc - tic) / its
|
||||
|
||||
|
||||
def time_little_einsum_path():
|
||||
subscripts = "ik,kj->ij"
|
||||
x = mx.ones((32, 32))
|
||||
y = mx.ones((32, 32))
|
||||
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
|
||||
|
||||
x = np.array(x)
|
||||
y = np.array(y)
|
||||
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
|
||||
print("Timing little einsum path...")
|
||||
print(f"MLX ... {mx_time:.3f} ms")
|
||||
print(f"NumPy... {np_time:.3f} ms")
|
||||
|
||||
|
||||
def time_big_einsum_path():
|
||||
chars = list("abcdefgh")
|
||||
char_to_dim = {c: v for v, c in enumerate(chars)}
|
||||
|
||||
num_inputs = 10
|
||||
inputs = []
|
||||
subscripts = []
|
||||
for _ in range(num_inputs):
|
||||
subscript = np.random.choice(chars, size=5, replace=False).tolist()
|
||||
subscripts.append("".join(subscript))
|
||||
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
|
||||
subscripts = ",".join(subscripts)
|
||||
|
||||
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
|
||||
|
||||
inputs = [mx.array(x) for x in inputs]
|
||||
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
|
||||
print("Timing big einsum path...")
|
||||
print(f"MLX ... {mx_time:.3f} ms")
|
||||
print(f"NumPy... {np_time:.3f} ms")
|
||||
|
||||
|
||||
def time_attention():
|
||||
def regular_attention(x):
|
||||
# shape [batch, sequence, num_heads, head_dim]
|
||||
queries, keys, values = x, x, x
|
||||
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
|
||||
mx.eval(output)
|
||||
|
||||
def einsum_attention(x):
|
||||
# shape [batch, sequence, num_heads, head_dim]
|
||||
queries, keys, values = x, x, x
|
||||
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
output = mx.einsum("ijtu,iujk->itjk", scores, values)
|
||||
mx.eval(output)
|
||||
|
||||
x = mx.random.uniform(shape=(8, 512, 32, 128))
|
||||
|
||||
regular_time = timeit(regular_attention, args=(x,))
|
||||
ein_time = timeit(einsum_attention, args=(x,))
|
||||
print("Timing einsum attention...")
|
||||
print(f"Regular ... {regular_time:.3f} ms")
|
||||
print(f"Einsum ... {ein_time:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_little_einsum_path()
|
||||
time_big_einsum_path()
|
||||
time_attention()
|
@@ -3,6 +3,8 @@
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -16,41 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||
|
||||
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
|
||||
def fft_mlx(x):
|
||||
if dim == 1:
|
||||
out = mx.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = mx.fft.fft2(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
def fft_mps(x):
|
||||
if dim == 1:
|
||||
out = torch.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = torch.fft.fft2(x)
|
||||
torch.mps.synchronize()
|
||||
return out
|
||||
|
||||
return bandwidths
|
||||
bandwidths = []
|
||||
for n in fft_sizes:
|
||||
batch_size = system_size // n**dim
|
||||
shape = [batch_size] + [n for _ in range(dim)]
|
||||
if backend == "mlx":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = mx.array(x_np)
|
||||
mx.eval(x)
|
||||
fft = fft_mlx
|
||||
elif backend == "mps":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = torch.tensor(x_np, device="mps")
|
||||
torch.mps.synchronize()
|
||||
fft = fft_mps
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
|
||||
print(n, bandwidth)
|
||||
bandwidths.append(bandwidth)
|
||||
|
||||
return np.array(bandwidths)
|
||||
|
||||
|
||||
def time_fft():
|
||||
x = np.array(range(2, 512))
|
||||
system_size = int(2**26)
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
|
||||
print("MLX GPU")
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
print("MPS GPU")
|
||||
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
|
||||
|
||||
print("CPU")
|
||||
system_size = int(2**20)
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
x = np.array(x)
|
||||
|
||||
all_indices = x - x[0]
|
||||
radix_2to13 = (
|
||||
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
bluesteins = (
|
||||
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
|
||||
for indices, name in [
|
||||
(all_indices, "All"),
|
||||
(radix_2to13, "Radix 2-13"),
|
||||
(bluesteins, "Bluestein's"),
|
||||
]:
|
||||
# plot bandwidths
|
||||
print(name)
|
||||
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
|
||||
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
|
||||
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
|
||||
plt.title(f"MLX FFT Benchmark -- {name}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"{name}.png")
|
||||
plt.clf()
|
||||
|
||||
av_gpu_bandwidth = np.mean(gpu_bandwidths)
|
||||
av_mps_bandwidth = np.mean(mps_bandwidths)
|
||||
av_cpu_bandwidth = np.mean(cpu_bandwidths)
|
||||
print("Average bandwidths:")
|
||||
print("GPU:", av_gpu_bandwidth)
|
||||
print("MPS:", av_mps_bandwidth)
|
||||
print("CPU:", av_cpu_bandwidth)
|
||||
|
||||
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
|
||||
print("Percent MLX faster than MPS: ", portion_faster * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
70
benchmarks/python/hadamard_bench.py
Normal file
70
benchmarks/python/hadamard_bench.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import argparse
|
||||
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def had(x):
|
||||
y = mx.hadamard_transform(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def copy(x):
|
||||
y = x + 1.0
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def run(dtype):
|
||||
system_size = 2**26
|
||||
outputs = {}
|
||||
for test_fn in (had, copy):
|
||||
for m in [1, 12, 20, 28]:
|
||||
if test_fn == copy:
|
||||
key = "copy"
|
||||
elif m == 1:
|
||||
key = "had_2^k"
|
||||
else:
|
||||
key = "had_m*2^k"
|
||||
outputs.setdefault(key, {})
|
||||
for k in range(7, 14):
|
||||
n = m * 2**k
|
||||
if n > 2**15:
|
||||
continue
|
||||
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
|
||||
x = mx.array(x_np)
|
||||
runtime_ms = measure_runtime(test_fn, x=x)
|
||||
bytes_per_gb = 1e9
|
||||
ms_per_s = 1e3
|
||||
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
|
||||
bandwidth_gb = (
|
||||
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
|
||||
)
|
||||
print(n, bandwidth_gb)
|
||||
outputs[key][n] = bandwidth_gb
|
||||
|
||||
colors = {
|
||||
"copy": "black",
|
||||
"had_2^k": "steelblue",
|
||||
"had_m*2^k": "skyblue",
|
||||
}
|
||||
for key, output in outputs.items():
|
||||
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
|
||||
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"bench_{dtype.__name__}.png")
|
||||
plt.clf()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
args = parser.parse_args()
|
||||
dtype = np.float16 if args.fp16 else np.float32
|
||||
run(dtype)
|
62
benchmarks/python/sdpa_bench.py
Normal file
62
benchmarks/python/sdpa_bench.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
MAX_SEQ = 300
|
||||
START_SEQ = 100
|
||||
SEQ_INCREMENT = 50
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_primitives(qs, ks, vs, alpha):
|
||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ vs
|
||||
return o
|
||||
|
||||
time_fn(sdpa_primitives, q, k, v, scale)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_fused(qs, ks, vs, alpha):
|
||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
||||
return o
|
||||
|
||||
time_fn(sdpa_fused, q, k, v, scale)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
@@ -1,36 +0,0 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
|
||||
@@ -1906,6 +1906,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
@@ -1,36 +0,0 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
|
||||
@@ -1918,6 +1918,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
@@ -1,3 +1,4 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
||||
mlx
|
||||
|
@@ -83,3 +83,15 @@ def setup(app):
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
|
||||
latex_elements = {
|
||||
"preamble": r"""
|
||||
\usepackage{enumitem}
|
||||
\setlistdepth{5}
|
||||
\setlist[itemize,1]{label=$\bullet$}
|
||||
\setlist[itemize,2]{label=$\bullet$}
|
||||
\setlist[itemize,3]{label=$\bullet$}
|
||||
\setlist[itemize,4]{label=$\bullet$}
|
||||
\setlist[itemize,5]{label=$\bullet$}
|
||||
\renewlist{itemize}{itemize}{5}
|
||||
""",
|
||||
}
|
||||
|
421
docs/src/dev/custom_metal_kernels.rst
Normal file
421
docs/src/dev/custom_metal_kernels.rst
Normal file
@@ -0,0 +1,421 @@
|
||||
Custom Metal Kernels
|
||||
====================
|
||||
|
||||
MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
* The shapes/dtypes of ``inputs``
|
||||
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
|
||||
so we will add ``const device float16_t* inp`` to the signature.
|
||||
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
|
||||
in ``source``.
|
||||
* The list of ``output_dtypes``
|
||||
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
|
||||
so we add ``device float16_t* out``.
|
||||
* Template parameters passed using ``template``
|
||||
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
|
||||
and instantiates the template with ``custom_kernel_myexp_float<float>``.
|
||||
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
|
||||
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
|
||||
These will be added as function arguments.
|
||||
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
|
||||
|
||||
Putting this all together, the generated function signature for ``myexp`` is as follows:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void custom_kernel_myexp_float(
|
||||
const device float16_t* inp [[buffer(0)]],
|
||||
device float16_t* out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
|
||||
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
|
||||
}
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
# make non-contiguous
|
||||
a = a[::2]
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Complex Example
|
||||
-----------------------------
|
||||
|
||||
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
|
||||
|
||||
We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
``55.7ms -> 6.7ms => 8x speed up``
|
||||
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
|
||||
* ``atomic_outputs=True``
|
||||
Designate all of the kernel outputs as ``atomic`` in the function signature.
|
||||
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
|
||||
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
|
||||
|
||||
We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
``676.4ms -> 16.7ms => 40x speed up``
|
@@ -486,9 +486,8 @@ below.
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available and look for it
|
||||
// in the same folder as this executable if needed
|
||||
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
@@ -15,7 +15,7 @@ module to concisely define the model architecture.
|
||||
Attention layer
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
We will start with the llama attention layer which notably uses the RoPE
|
||||
We will start with the Llama attention layer which notably uses the RoPE
|
||||
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
||||
key/value cache that will be concatenated with the provided keys and values to
|
||||
support efficient inference.
|
||||
|
@@ -64,7 +64,7 @@ set:
|
||||
Next, setup the problem parameters and load the data. To load the data, you need our
|
||||
`mnist data loader
|
||||
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
|
||||
we will import as `mnist`.
|
||||
we will import as ``mnist``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@@ -43,6 +43,7 @@ are the CPU and GPU.
|
||||
usage/function_transforms
|
||||
usage/compile
|
||||
usage/numpy
|
||||
usage/distributed
|
||||
usage/using_streams
|
||||
|
||||
.. toctree::
|
||||
@@ -69,6 +70,7 @@ are the CPU and GPU.
|
||||
python/metal
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/distributed
|
||||
python/tree_utils
|
||||
|
||||
.. toctree::
|
||||
@@ -83,3 +85,4 @@ are the CPU and GPU.
|
||||
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
dev/custom_metal_kernels
|
||||
|
@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
|
||||
Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
||||
|
||||
For developing use an editable install:
|
||||
For developing, install the package with development dependencies, and use an
|
||||
editable install:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
||||
|
||||
To make sure the install is working run the tests with:
|
||||
Once the development dependencies are installed, you can build faster with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
||||
|
||||
Run the tests with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install ".[testing]"
|
||||
python -m unittest discover python/tests
|
||||
|
||||
Optional: Install stubs to enable auto completions and type checking from your IDE:
|
||||
Optional: Install stubs to enable auto completions and type checking from your
|
||||
IDE:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install ".[dev]"
|
||||
python setup.py generate_stubs
|
||||
|
||||
C++ API
|
||||
@@ -186,8 +186,8 @@ should point to the path to the built metal library.
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
|
||||
and `BUILD_SHARED_LIBS=ON`.
|
||||
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
|
||||
and ``BUILD_SHARED_LIBS=ON``.
|
||||
|
||||
The MLX CMake build has several additional options to make smaller binaries.
|
||||
For example, if you don't need the CPU backend or support for safetensors and
|
||||
@@ -195,7 +195,7 @@ GGUF, you can do:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
cmake ..
|
||||
cmake .. \
|
||||
-DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
@@ -203,7 +203,7 @@ GGUF, you can do:
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
|
||||
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
|
||||
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
|
||||
contains pre-built GPU kernels. This substantially reduces the size of the
|
||||
Metal library by run-time compiling kernels the first time they are used in MLX
|
||||
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||
|
@@ -24,6 +24,7 @@ Array
|
||||
array.any
|
||||
array.argmax
|
||||
array.argmin
|
||||
array.conj
|
||||
array.cos
|
||||
array.cummax
|
||||
array.cummin
|
||||
@@ -52,8 +53,10 @@ Array
|
||||
array.sqrt
|
||||
array.square
|
||||
array.squeeze
|
||||
array.swapaxes
|
||||
array.std
|
||||
array.sum
|
||||
array.swapaxes
|
||||
array.transpose
|
||||
array.T
|
||||
array.var
|
||||
array.view
|
||||
|
22
docs/src/python/distributed.rst
Normal file
22
docs/src/python/distributed.rst
Normal file
@@ -0,0 +1,22 @@
|
||||
.. _distributed:
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
Distributed Communication
|
||||
==========================
|
||||
|
||||
MLX provides a distributed communication package using MPI. The MPI library is
|
||||
loaded at runtime; if MPI is available then distributed communication is also
|
||||
made available.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Group
|
||||
is_available
|
||||
init
|
||||
all_sum
|
||||
all_gather
|
||||
send
|
||||
recv
|
||||
recv_like
|
@@ -12,3 +12,5 @@ Fast
|
||||
layer_norm
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
affine_quantize
|
||||
metal_kernel
|
||||
|
@@ -9,7 +9,9 @@ Linear Algebra
|
||||
:toctree: _autosummary
|
||||
|
||||
inv
|
||||
tri_inv
|
||||
norm
|
||||
cholesky
|
||||
cholesky_inv
|
||||
qr
|
||||
svd
|
||||
|
@@ -17,6 +17,8 @@ simple functions.
|
||||
gelu_approx
|
||||
gelu_fast_approx
|
||||
glu
|
||||
hard_shrink
|
||||
hard_tanh
|
||||
hardswish
|
||||
leaky_relu
|
||||
log_sigmoid
|
||||
@@ -29,6 +31,7 @@ simple functions.
|
||||
sigmoid
|
||||
silu
|
||||
softmax
|
||||
softmin
|
||||
softplus
|
||||
softshrink
|
||||
step
|
||||
|
@@ -16,15 +16,23 @@ Layers
|
||||
Conv1d
|
||||
Conv2d
|
||||
Conv3d
|
||||
ConvTranspose1d
|
||||
ConvTranspose2d
|
||||
ConvTranspose3d
|
||||
Dropout
|
||||
Dropout2d
|
||||
Dropout3d
|
||||
Embedding
|
||||
GELU
|
||||
GLU
|
||||
GroupNorm
|
||||
GRU
|
||||
HardShrink
|
||||
HardTanh
|
||||
Hardswish
|
||||
InstanceNorm
|
||||
LayerNorm
|
||||
LeakyReLU
|
||||
Linear
|
||||
LSTM
|
||||
MaxPool1d
|
||||
@@ -36,13 +44,19 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
ReLU6
|
||||
RNN
|
||||
RoPE
|
||||
SELU
|
||||
Sequential
|
||||
SiLU
|
||||
SinusoidalPositionalEncoding
|
||||
Softmin
|
||||
Softshrink
|
||||
Softsign
|
||||
Softmax
|
||||
Softplus
|
||||
Step
|
||||
Tanh
|
||||
Transformer
|
||||
Upsample
|
||||
|
@@ -44,6 +44,10 @@ Operations
|
||||
convolve
|
||||
conv1d
|
||||
conv2d
|
||||
conv3d
|
||||
conv_transpose1d
|
||||
conv_transpose2d
|
||||
conv_transpose3d
|
||||
conv_general
|
||||
cos
|
||||
cosh
|
||||
@@ -57,6 +61,8 @@ Operations
|
||||
diagonal
|
||||
divide
|
||||
divmod
|
||||
einsum
|
||||
einsum_path
|
||||
equal
|
||||
erf
|
||||
erfinv
|
||||
@@ -72,8 +78,10 @@ Operations
|
||||
gather_qmm
|
||||
greater
|
||||
greater_equal
|
||||
hadamard_transform
|
||||
identity
|
||||
inner
|
||||
isfinite
|
||||
isclose
|
||||
isinf
|
||||
isnan
|
||||
@@ -103,6 +111,7 @@ Operations
|
||||
minimum
|
||||
moveaxis
|
||||
multiply
|
||||
nan_to_num
|
||||
negative
|
||||
not_equal
|
||||
ones
|
||||
@@ -156,6 +165,7 @@ Operations
|
||||
tril
|
||||
triu
|
||||
var
|
||||
view
|
||||
where
|
||||
zeros
|
||||
zeros_like
|
||||
|
@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
|
||||
# Compute the new parameters but also the optimizer state.
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
Saving and Loading
|
||||
------------------
|
||||
|
||||
To serialize an optimizer, save its state. To load an optimizer, load and set
|
||||
the saved state. Here's a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
import mlx.optimizers as optim
|
||||
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
# Perform some updates with the optimizer
|
||||
model = {"w" : mx.zeros((5, 5))}
|
||||
grads = {"w" : mx.ones((5, 5))}
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
|
||||
parameters are not. A good rule of thumb is if the parameter can be scheduled
|
||||
then it will be included in the optimizer state.
|
||||
|
||||
.. toctree::
|
||||
|
||||
optimizers/optimizer
|
||||
|
@@ -44,3 +44,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||
split
|
||||
truncated_normal
|
||||
uniform
|
||||
laplace
|
||||
|
@@ -10,6 +10,7 @@ Transforms
|
||||
|
||||
eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
enable_compile
|
||||
grad
|
||||
|
166
docs/src/usage/distributed.rst
Normal file
166
docs/src/usage/distributed.rst
Normal file
@@ -0,0 +1,166 @@
|
||||
.. _usage_distributed:
|
||||
|
||||
Distributed Communication
|
||||
=========================
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
|
||||
provide distributed communication operations that allow the computational cost
|
||||
of training or inference to be shared across many physical machines. You can
|
||||
see a list of the supported operations in the :ref:`API docs<distributed>`.
|
||||
|
||||
.. note::
|
||||
A lot of operations may not be supported or not as fast as they should be.
|
||||
We are adding more and tuning the ones we have as we are figuring out the
|
||||
best way to do distributed computing on Macs using MLX.
|
||||
|
||||
Getting Started
|
||||
---------------
|
||||
|
||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
||||
machine. The minimal distributed program in MLX is as simple as:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
world = mx.distributed.init()
|
||||
x = mx.distributed.all_sum(mx.ones(10))
|
||||
print(world.rank(), x)
|
||||
|
||||
The program above sums the array ``mx.ones(10)`` across all
|
||||
distributed processes. If simply run with ``python``, however, only one
|
||||
process is launched and no distributed communication takes place.
|
||||
|
||||
To launch the program in distributed mode we need to use ``mpirun`` or
|
||||
``mpiexec`` depending on the MPI installation. The simplest possible way is the
|
||||
following:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 python test.py
|
||||
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
|
||||
The above launches two processes on the same (local) machine and we can see
|
||||
both standard output streams. The processes send the array of 1s to each other
|
||||
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
|
||||
print 4 etc.
|
||||
|
||||
Installing MPI
|
||||
---------------
|
||||
|
||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||
with the Anaconda package manager as follows:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ conda install openmpi
|
||||
|
||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||
|
||||
Setting up Remote Hosts
|
||||
-----------------------
|
||||
|
||||
MPI can automatically connect to remote hosts and set up the communication over
|
||||
the network if the remote hosts can be accessed via ssh. A good checklist to
|
||||
debug connectivity issues is the following:
|
||||
|
||||
* ``ssh hostname`` works from all machines to all machines without asking for
|
||||
password or host confirmation
|
||||
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
|
||||
full path to force all machines to use a specific path.
|
||||
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
||||
in the ``.ssh/config`` files on all machines.
|
||||
|
||||
.. note::
|
||||
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
|
||||
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
|
||||
|
||||
An easy way to pass the host names to MPI is using a host file. A host file
|
||||
looks like the following, where ``host1`` and ``host2`` should be the fully
|
||||
qualified domain names or IPs for these hosts.
|
||||
|
||||
.. code::
|
||||
|
||||
host1 slots=1
|
||||
host2 slots=1
|
||||
|
||||
When using MLX, it is very likely that you want to use 1 slot per host, ie one
|
||||
process per host. The hostfile also needs to contain the current
|
||||
host if you want to run on the local host. Passing the host file to
|
||||
``mpirun`` is simply done using the ``--hostfile`` command line argument.
|
||||
|
||||
Training Example
|
||||
----------------
|
||||
|
||||
In this section we will adapt an MLX training loop to support data parallel
|
||||
distributed training. Namely, we will average the gradients across a set of
|
||||
hosts before applying them to the model.
|
||||
|
||||
Our training loop looks like the following code snippet if we omit the model,
|
||||
dataset and optimizer initialization.
|
||||
|
||||
.. code:: python
|
||||
|
||||
model = ...
|
||||
optimizer = ...
|
||||
dataset = ...
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
for x, y in dataset:
|
||||
loss = step(model, x, y)
|
||||
mx.eval(loss, model.parameters())
|
||||
|
||||
All we have to do to average the gradients across machines is perform an
|
||||
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
|
||||
have to :func:`mlx.utils.tree_map` the gradients with following function.
|
||||
|
||||
.. code:: python
|
||||
|
||||
def all_avg(x):
|
||||
return mx.distributed.all_sum(x) / mx.distributed.init().size()
|
||||
|
||||
Putting everything together our training loop step looks as follows with
|
||||
everything else remaining the same.
|
||||
|
||||
.. code:: python
|
||||
|
||||
from mlx.utils import tree_map
|
||||
|
||||
def all_reduce_grads(grads):
|
||||
N = mx.distributed.init()
|
||||
if N == 1:
|
||||
return grads
|
||||
return tree_map(
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads)
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = all_reduce_grads(grads) # <--- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
Tuning All Reduce
|
||||
-----------------
|
||||
|
||||
We are working on improving the performance of all reduce on MLX but for now
|
||||
the two main things one can do to extract the most out of distributed training with MLX are:
|
||||
|
||||
1. Perform a few large reductions instead of many small ones to improve
|
||||
bandwidth and latency
|
||||
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
|
||||
connections between each host to improve bandwidth
|
@@ -3,7 +3,11 @@
|
||||
Conversion to NumPy and Other Frameworks
|
||||
========================================
|
||||
|
||||
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||
MLX array supports conversion between other frameworks with either:
|
||||
|
||||
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
|
||||
|
||||
Let's convert an array to NumPy and back.
|
||||
|
||||
.. code-block:: python
|
||||
|
@@ -16,7 +16,7 @@ int main() {
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
array x = ones({10});
|
||||
array out = distributed::all_reduce_sum(x, global_group);
|
||||
array out = distributed::all_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
||||
|
@@ -249,9 +249,8 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available and look for it
|
||||
// in the same folder as this executable if needed
|
||||
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
@@ -2,7 +2,7 @@
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"mlx>=0.9.0",
|
||||
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
|
||||
"mlx>=0.17.0",
|
||||
"nanobind==2.1.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.9.0
|
||||
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
mlx>=0.17.0
|
||||
nanobind==2.1.0
|
||||
|
@@ -13,7 +13,6 @@ if __name__ == "__main__":
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev": []},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
@@ -6,6 +6,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
|
@@ -23,11 +23,22 @@ void free(Buffer buffer) {
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
return Buffer{std::malloc(size)};
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
std::free(buffer.raw_ptr());
|
||||
std::free(buffer.ptr());
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
Buffer malloc_or_wait(size_t size) {
|
||||
|
@@ -41,6 +41,7 @@ class Allocator {
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
Allocator() = default;
|
||||
Allocator(const Allocator& other) = delete;
|
||||
@@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
|
@@ -17,6 +17,10 @@ bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -102,7 +106,7 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
@@ -171,10 +175,11 @@ array::~array() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that will be detached
|
||||
if (status() != array::Status::unscheduled) {
|
||||
// Ignore arrays that might be detached during eval
|
||||
if (status() == array::Status::scheduled) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
bool do_detach = true;
|
||||
@@ -206,7 +211,7 @@ void array::ArrayDesc::init() {
|
||||
strides[i] = size;
|
||||
size *= shape[i];
|
||||
}
|
||||
for (auto& in : inputs) {
|
||||
for (const auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
@@ -231,7 +236,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destory their corresponding descriptions and so on and so
|
||||
// that may also destroy their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This calls recursively the destructor and can result in stack overflow, we
|
||||
|
86
mlx/array.h
86
mlx/array.h
@@ -73,32 +73,32 @@ class array {
|
||||
this->array_desc_ = other.array_desc_;
|
||||
}
|
||||
return *this;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size of the array's datatype in bytes. */
|
||||
size_t itemsize() const {
|
||||
return size_of(dtype());
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of elements in the array. */
|
||||
size_t size() const {
|
||||
return array_desc_->size;
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of bytes in the array. */
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
};
|
||||
}
|
||||
|
||||
/** The number of dimensions of the array. */
|
||||
size_t ndim() const {
|
||||
return array_desc_->shape.size();
|
||||
};
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
return array_desc_->shape;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the size of the corresponding dimension.
|
||||
@@ -107,12 +107,12 @@ class array {
|
||||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
return array_desc_->strides;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the stride of the corresponding dimension.
|
||||
@@ -121,12 +121,12 @@ class array {
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
};
|
||||
}
|
||||
|
||||
/** Get the arrays data type. */
|
||||
Dtype dtype() const {
|
||||
return array_desc_->dtype;
|
||||
};
|
||||
}
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval();
|
||||
@@ -160,10 +160,10 @@ class array {
|
||||
|
||||
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
||||
};
|
||||
}
|
||||
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return !(a == b);
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
const array& arr;
|
||||
@@ -209,7 +209,7 @@ class array {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d) {};
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
@@ -219,33 +219,45 @@ class array {
|
||||
};
|
||||
|
||||
struct Flags {
|
||||
// True if there are no gaps in the underlying data. Each item
|
||||
// True iff there are no gaps in the underlying data. Each item
|
||||
// in the underlying data buffer belongs to at least one index.
|
||||
//
|
||||
// True iff:
|
||||
// prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
|
||||
bool contiguous : 1;
|
||||
|
||||
// True iff:
|
||||
// strides[-1] == 1 and
|
||||
// all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
|
||||
// range(ndim - 1))
|
||||
bool row_contiguous : 1;
|
||||
|
||||
// True iff:
|
||||
// strides[0] == 1 and
|
||||
// all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
|
||||
// range(1, ndim))
|
||||
bool col_contiguous : 1;
|
||||
};
|
||||
|
||||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
};
|
||||
}
|
||||
|
||||
/** A shared pointer to the array's primitive. */
|
||||
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||
return array_desc_->primitive;
|
||||
};
|
||||
}
|
||||
|
||||
/** Check if the array has an attached primitive or is a leaf node. */
|
||||
bool has_primitive() const {
|
||||
return array_desc_->primitive != nullptr;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's inputs. */
|
||||
const std::vector<array>& inputs() const {
|
||||
return array_desc_->inputs;
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
@@ -259,12 +271,12 @@ class array {
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
std::vector<array>& siblings() {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
@@ -281,7 +293,7 @@ class array {
|
||||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
};
|
||||
}
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
@@ -289,19 +301,32 @@ class array {
|
||||
/** Get the Flags bit-field. */
|
||||
const Flags& flags() const {
|
||||
return array_desc_->flags;
|
||||
};
|
||||
}
|
||||
|
||||
/** The size (in elements) of the underlying buffer the array points to. */
|
||||
/** The size (in elements) of the underlying buffer the array points to.
|
||||
*
|
||||
* This can be different than the actual size of the array if the array has
|
||||
* been broadcast or irregularly strided. If ``first`` is the offset into
|
||||
* the data buffer of the first element of the array (i.e. the offset
|
||||
* corresponding to ``arr[0, 0, ...]``) and last is the offset into the
|
||||
* data buffer of the last element of the array (i.e. the offset
|
||||
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
|
||||
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
|
||||
**/
|
||||
size_t data_size() const {
|
||||
return array_desc_->data_size;
|
||||
};
|
||||
}
|
||||
|
||||
allocator::Buffer& buffer() {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
const allocator::Buffer& buffer() const {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
}
|
||||
|
||||
size_t buffer_size() const {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
@@ -312,19 +337,20 @@ class array {
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
}
|
||||
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
const Status status() const {
|
||||
|
||||
Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
@@ -411,8 +437,6 @@ class array {
|
||||
void* data_ptr{nullptr};
|
||||
|
||||
// The size in elements of the data buffer the array accesses
|
||||
// This can be different than the actual size of the array if it
|
||||
// has been broadcast or irregularly strided.
|
||||
size_t data_size;
|
||||
|
||||
// Contains useful meta data about the array
|
||||
|
@@ -1,9 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@@ -2,8 +2,7 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <vecLib/BNNS/bnns.h>
|
||||
#include <vecLib/cblas_new.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/backend/accelerate/utils.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
|
@@ -3,8 +3,7 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <vecLib/vDSP.h>
|
||||
#include <vecLib/vForce.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
@@ -37,7 +36,7 @@ DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
@@ -51,6 +50,7 @@ DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
@@ -102,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -117,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -132,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -300,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -315,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,12 +326,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,12 +389,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -408,7 +400,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -423,7 +415,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,7 +426,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return -x; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,7 +513,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return x * x; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -547,7 +539,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -565,7 +557,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary(
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -577,7 +569,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@@ -3,7 +3,10 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
@@ -53,25 +56,25 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
|
||||
return (*(simd_float16*)&epart) * x;
|
||||
}
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
/**
|
||||
* The ARM neon equivalent of the fast exp above.
|
||||
*/
|
||||
inline float16x8_t neon_fast_exp(float16x8_t x) {
|
||||
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
|
||||
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
|
||||
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
|
||||
float16x8_t fpart = vsubq_f16(x, ipart);
|
||||
|
||||
x = vdupq_n_f16(1.535336188319500e-4f);
|
||||
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
|
||||
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
@@ -107,53 +110,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return *(VT*)a;
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
*(VT*)dst = x;
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
};
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return simd_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return simd_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct NeonFp16SimdOps {
|
||||
VT init(T a) {
|
||||
@@ -170,7 +126,7 @@ struct NeonFp16SimdOps {
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
};
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
@@ -201,6 +157,55 @@ struct NeonFp16SimdOps {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return *(VT*)a;
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
*(VT*)dst = x;
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return simd_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return simd_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
@@ -362,12 +367,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
eval(inputs, out); // Redirect to common backend for consistency
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
|
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vecLib/BNNS/bnns.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
@@ -42,12 +42,15 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
|
@@ -196,6 +196,20 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
|
@@ -43,13 +43,15 @@ void set_binary_op_output_data(
|
||||
array& out,
|
||||
BinaryOpType bopt,
|
||||
bool donate_with_move = false) {
|
||||
bool b_donatable = is_donatable(b, out);
|
||||
bool a_donatable = is_donatable(a, out);
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
@@ -64,7 +66,7 @@ void set_binary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
@@ -79,13 +81,13 @@ void set_binary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
} else if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
@@ -100,16 +102,14 @@ void set_binary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
if (a.is_donatable() && a.flags().row_contiguous &&
|
||||
a.itemsize() == out.itemsize() && a.size() == out.size()) {
|
||||
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else if (
|
||||
b.is_donatable() && b.flags().row_contiguous &&
|
||||
b.itemsize() == out.itemsize() && b.size() == out.size()) {
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
|
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomVJP::eval(
|
||||
void CustomTransforms::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
@@ -250,49 +250,6 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
|
@@ -18,7 +18,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
os << static_cast<int32_t>(x.item<int8_t>());
|
||||
return;
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
@@ -26,7 +27,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
os << static_cast<uint32_t>(x.item<uint8_t>());
|
||||
return;
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
@@ -205,8 +207,8 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
|
@@ -1125,7 +1125,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
else {
|
||||
std::ostringstream msg;
|
||||
msg << "[Convolution::eval] Convolution currently only supports"
|
||||
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
|
||||
<< " spatial dimensions";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -142,29 +143,31 @@ void copy_general(
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
switch (src.ndim()) {
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<stride_t>>{i_strides});
|
||||
switch (new_shape.size()) {
|
||||
case 1:
|
||||
copy_general_dim1<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_dim2<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_dim3<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_dim4<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
|
||||
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
|
||||
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||
}
|
||||
}
|
||||
@@ -195,10 +198,10 @@ inline void copy_general_general_dims(
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
if constexpr (D > 1) {
|
||||
int axis = src.ndim() - D;
|
||||
int axis = data_shape.size() - D;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
@@ -209,7 +212,7 @@ inline void copy_general_general_dims(
|
||||
o_offset += stride_dst;
|
||||
}
|
||||
} else {
|
||||
int axis = src.ndim() - 1;
|
||||
int axis = data_shape.size() - 1;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
@@ -230,38 +233,76 @@ void copy_general_general(
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
switch (src.ndim()) {
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
|
||||
switch (new_shape.size()) {
|
||||
case 1:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
return;
|
||||
case 5:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
int size = std::accumulate(
|
||||
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
|
||||
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < src.size(); i += size) {
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
src_offset,
|
||||
dst_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -444,8 +485,17 @@ void copy_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void copy_inplace<int64_t>(
|
||||
template void copy_inplace<size_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<size_t>& i_strides,
|
||||
const std::vector<size_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
template void copy_inplace<int64_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
@@ -453,24 +503,6 @@ void copy_inplace<int64_t>(
|
||||
const std::vector<int64_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
return copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
}
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -5,7 +5,6 @@
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -53,7 +52,7 @@ DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
@@ -69,6 +68,7 @@ DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
|
107
mlx/backend/common/hadamard.cpp
Normal file
107
mlx/backend/common/hadamard.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// n = 2^k component
|
||||
template <typename T>
|
||||
void hadamard_n(array& out, int n, int m, float scale) {
|
||||
for (int b = 0; b < out.size() / n; b++) {
|
||||
size_t loc = b * n;
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
int h = 1;
|
||||
int n_over_2 = n / 2;
|
||||
while (h < n) {
|
||||
for (int i = 0; i < n / 2; i++) {
|
||||
int k = i & (h - 1);
|
||||
int j = ((i - k) << 1) + k;
|
||||
float x = *(data_ptr + j);
|
||||
float y = *(data_ptr + j + h);
|
||||
*(data_ptr + j) = x + y;
|
||||
*(data_ptr + j + h) = x - y;
|
||||
if (h == n_over_2) {
|
||||
*(data_ptr + j) *= scale;
|
||||
*(data_ptr + j + h) *= scale;
|
||||
}
|
||||
}
|
||||
h <<= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// m component
|
||||
template <typename T>
|
||||
void hadamard_m(array& out, int n, int m, float scale) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
auto start = 1;
|
||||
auto end = matrix.find('\n', start);
|
||||
std::vector<bool> hmat_vec;
|
||||
while (end != std::string_view::npos) {
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
hmat_vec.push_back(row[i] == '+');
|
||||
}
|
||||
start = end + 1;
|
||||
end = matrix.find('\n', start);
|
||||
}
|
||||
|
||||
for (int b = 0; b < out.size() / m / n; b++) {
|
||||
size_t loc = b * n * m;
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
for (int i = 0; i < n; i++) {
|
||||
std::vector<float> out(m);
|
||||
for (int j = 0; j < m; j++) {
|
||||
for (int k = 0; k < m; k++) {
|
||||
float x = *(data_ptr + i + k * n);
|
||||
if (hmat_vec[k + j * m]) {
|
||||
out[j] += x;
|
||||
} else {
|
||||
out[j] -= x;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < m; j++) {
|
||||
*(data_ptr + i + j * n) = out[j] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void hadamard(array& out, int n, int m, float scale) {
|
||||
float n_scale = m > 1 ? 1.0 : scale;
|
||||
hadamard_n<T>(out, n, m, n_scale);
|
||||
if (m > 1) {
|
||||
hadamard_m<T>(out, n, m, scale);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Copy input to output
|
||||
copy(in, out, CopyType::General);
|
||||
|
||||
int axis = out.ndim() - 1;
|
||||
auto [n, m] = decompose_hadamard(out.shape(axis));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
return hadamard<float>(out, n, m, scale_);
|
||||
case float16:
|
||||
return hadamard<float16_t>(out, n, m, scale_);
|
||||
case bfloat16:
|
||||
return hadamard<bfloat16_t>(out, n, m, scale_);
|
||||
default:
|
||||
throw std::invalid_argument("[hadamard] Unsupported type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
105
mlx/backend/common/hadamard.h
Normal file
105
mlx/backend/common/hadamard.h
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// From http://neilsloane.com/hadamard/
|
||||
constexpr std::string_view h12 = R"(
|
||||
+-++++++++++
|
||||
--+-+-+-+-+-
|
||||
+++-++----++
|
||||
+---+--+-++-
|
||||
+++++-++----
|
||||
+-+---+--+-+
|
||||
++--+++-++--
|
||||
+--++---+--+
|
||||
++----+++-++
|
||||
+--+-++---+-
|
||||
++++----+++-
|
||||
+-+--+-++---
|
||||
)";
|
||||
|
||||
constexpr std::string_view h20 = R"(
|
||||
+----+----++--++-++-
|
||||
-+----+---+++---+-++
|
||||
--+----+---+++-+-+-+
|
||||
---+----+---+++++-+-
|
||||
----+----++--++-++-+
|
||||
-+++++-----+--+++--+
|
||||
+-+++-+---+-+--+++--
|
||||
++-++--+---+-+--+++-
|
||||
+++-+---+---+-+--+++
|
||||
++++-----++--+-+--++
|
||||
--++-+-++-+-----++++
|
||||
---++-+-++-+---+-+++
|
||||
+---++-+-+--+--++-++
|
||||
++---++-+----+-+++-+
|
||||
-++---++-+----+++++-
|
||||
-+--+--++-+----+----
|
||||
+-+-----++-+----+---
|
||||
-+-+-+---+--+----+--
|
||||
--+-+++------+----+-
|
||||
+--+--++------+----+
|
||||
)";
|
||||
|
||||
constexpr std::string_view h28 = R"(
|
||||
+------++----++-+--+-+--++--
|
||||
-+-----+++-----+-+--+-+--++-
|
||||
--+-----+++---+-+-+----+--++
|
||||
---+-----+++---+-+-+-+--+--+
|
||||
----+-----+++---+-+-+++--+--
|
||||
-----+-----++++--+-+--++--+-
|
||||
------++----++-+--+-+--++--+
|
||||
--++++-+-------++--+++-+--+-
|
||||
---++++-+-----+-++--+-+-+--+
|
||||
+---+++--+----++-++--+-+-+--
|
||||
++---++---+----++-++--+-+-+-
|
||||
+++---+----+----++-++--+-+-+
|
||||
++++--------+-+--++-++--+-+-
|
||||
-++++--------+++--++--+--+-+
|
||||
-+-++-++--++--+--------++++-
|
||||
+-+-++--+--++--+--------++++
|
||||
-+-+-++--+--++--+----+---+++
|
||||
+-+-+-++--+--+---+---++---++
|
||||
++-+-+-++--+------+--+++---+
|
||||
-++-+-+-++--+------+-++++---
|
||||
+-++-+---++--+------+-++++--
|
||||
-++--++-+-++-+++----++------
|
||||
+-++--++-+-++-+++-----+-----
|
||||
++-++---+-+-++-+++-----+----
|
||||
-++-++-+-+-+-+--+++-----+---
|
||||
--++-++++-+-+----+++-----+--
|
||||
+--++-+-++-+-+----+++-----+-
|
||||
++--++-+-++-+-+----++------+
|
||||
)";
|
||||
|
||||
inline const std::map<int, std::string_view> hadamard_matrices() {
|
||||
return {{12, h12}, {20, h20}, {28, h28}};
|
||||
}
|
||||
|
||||
inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
// n = m*2^k
|
||||
int m = 1;
|
||||
if (!is_power_of_2(n)) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
for (auto [factor, _] : h_matrices) {
|
||||
if (n % factor == 0) {
|
||||
m = factor;
|
||||
n /= factor;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (m == 1) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -10,9 +10,106 @@
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
// Wrapper to account for differences in
|
||||
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
|
||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
||||
int info;
|
||||
|
||||
#ifdef LAPACK_FORTRAN_STRLEN_END
|
||||
strtri_(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info,
|
||||
/* uplo_len = */ static_cast<size_t>(1),
|
||||
/* diag_len = */ static_cast<size_t>(1));
|
||||
#else
|
||||
strtri_(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
#endif
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void inverse_impl(const array& a, array& inv) {
|
||||
void general_inv(array& inv, int N, int i) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
void tri_inv(array& inv, int N, int i, bool upper) {
|
||||
const char uplo = upper ? 'L' : 'U';
|
||||
const char diag = 'N';
|
||||
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: triangular inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the following
|
||||
// identity to avoid transposing (see
|
||||
// https://math.stackexchange.com/a/340234):
|
||||
@@ -24,63 +121,11 @@ void inverse_impl(const array& a, array& inv) {
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
if (tri) {
|
||||
tri_inv(inv, N, i, upper);
|
||||
} else {
|
||||
general_inv(inv, N, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -89,7 +134,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
inverse_impl(inputs[0], output);
|
||||
inverse_impl(inputs[0], output, tri_, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -5,11 +5,9 @@
|
||||
#include <utility>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/io/load.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <const uint8_t scalar_size>
|
||||
@@ -29,12 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
namespace mlx::core {
|
||||
|
||||
reader_->seek(offset_, std::ios_base::beg);
|
||||
reader_->read(out.data<char>(), out.nbytes());
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianness_) {
|
||||
reader->read(out.data<char>(), out.nbytes(), offset);
|
||||
|
||||
if (swap_endianness_) {
|
||||
switch (out.itemsize()) {
|
||||
@@ -51,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Load::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
load(out, offset_, reader_, swap_endianness_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
14
mlx/backend/common/load.h
Normal file
14
mlx/backend/common/load.h
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/io/load.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void load(
|
||||
array& out,
|
||||
size_t offset,
|
||||
const std::shared_ptr<io::Reader>& reader,
|
||||
bool swap_endianess);
|
||||
|
||||
} // namespace mlx::core
|
@@ -21,13 +21,14 @@ EOM
|
||||
|
||||
fi
|
||||
|
||||
CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
|
||||
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
const char* get_kernel_preamble() {
|
||||
return R"preamble(
|
||||
$INCLUDES
|
||||
$CONTENT
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::detail;
|
||||
)preamble";
|
||||
}
|
||||
|
@@ -108,105 +108,105 @@ struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acos(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acosh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asin(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asinh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atan(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atanh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::ceil(x);
|
||||
};
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
@@ -219,35 +219,35 @@ struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cos(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cosh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return fast_exp(x);
|
||||
};
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::exp(x);
|
||||
@@ -258,83 +258,83 @@ struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::floor(x);
|
||||
};
|
||||
}
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log2(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log10(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
@@ -373,55 +373,59 @@ struct Sign {
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x != 0;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return x == complex64_t(0) ? x : x / std::abs(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sin(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sinh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sqrt(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tan(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tanh(x);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Add {
|
||||
@@ -554,7 +558,7 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
@@ -602,14 +606,14 @@ struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct Select {
|
||||
@@ -623,35 +627,35 @@ struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
@@ -8,9 +8,9 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -313,20 +313,6 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -419,7 +405,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto out_strides = make_contiguous_strides<size_t>(in.shape());
|
||||
copy_inplace<size_t>(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
out_strides,
|
||||
0,
|
||||
0,
|
||||
CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
@@ -492,7 +488,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
@@ -508,8 +505,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, out);
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -590,4 +595,43 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < strides.size() - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(
|
||||
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
copy_inplace(in_tmp, tmp, CopyType::General);
|
||||
} else {
|
||||
copy_inplace(in, tmp, CopyType::General);
|
||||
}
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -87,6 +87,38 @@ struct OrReduce {
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct MinReduce {
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
|
||||
if (std::isnan(x)) {
|
||||
*y = x;
|
||||
} else {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InT>
|
||||
void reduce_dispatch_out(
|
||||
const array& in,
|
||||
@@ -104,63 +136,27 @@ void reduce_dispatch_out(
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
reduction_op<InT, bool>(in, out, axes, false, op);
|
||||
break;
|
||||
case uint8:
|
||||
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint16:
|
||||
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint32:
|
||||
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint64:
|
||||
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int8:
|
||||
reduction_op<InT, int8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int16:
|
||||
reduction_op<InT, int16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int32:
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int64:
|
||||
reduction_op<InT, int64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case float16:
|
||||
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case float32:
|
||||
reduction_op<InT, float>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case complex64:
|
||||
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
|
||||
break;
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
}
|
||||
} break;
|
||||
break;
|
||||
}
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
break;
|
||||
}
|
||||
case Reduce::Max: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT>(in, out, axes, init, op);
|
||||
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
|
||||
break;
|
||||
}
|
||||
case Reduce::Min: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT>(in, out, axes, init, op);
|
||||
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -168,6 +164,29 @@ void reduce_dispatch_out(
|
||||
|
||||
} // namespace
|
||||
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@@ -49,47 +49,18 @@ struct ReductionPlan {
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
inline void nd_loop(
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
const std::vector<size_t>& strides);
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
const std::vector<int>& axes);
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedReduce {
|
||||
@@ -123,102 +94,6 @@ struct DefaultContiguousReduce {
|
||||
}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename OpS, typename OpC, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
@@ -361,6 +236,4 @@ void reduction_op(
|
||||
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
147
mlx/backend/common/reduce_utils.cpp
Normal file
147
mlx/backend/common/reduce_utils.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove singleton axes from the plan
|
||||
for (int i = shape.size() - 1; i >= 0; i--) {
|
||||
if (shape[i] == 1) {
|
||||
shape.erase(shape.begin() + i);
|
||||
strides.erase(strides.begin() + i);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 2. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
if (x.shape(a) > 1) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
bool a_is_zero = a.second == 0;
|
||||
bool b_is_zero = b.second == 0;
|
||||
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
bool have_expand = false;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t stride_i = x.strides()[i];
|
||||
int shape_i = x.shape(i);
|
||||
if (stride_i == 0) {
|
||||
if (shape_i == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
have_expand = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (stride_i != size && shape_i != 1) {
|
||||
break;
|
||||
}
|
||||
size *= shape_i;
|
||||
}
|
||||
// In the case of an expanded dimension we are being conservative and
|
||||
// require the smallest reduction stride to be smaller than the maximum row
|
||||
// contiguous size. The reason is that we can't easily know if the reduced
|
||||
// axis is before or after an expanded dimension.
|
||||
if (size > strides.back() || (size == strides.back() && !have_expand)) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -234,7 +234,7 @@ void scan_dispatch(
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
: std::numeric_limits<U>::min();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
|
||||
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
|
||||
|
40
mlx/backend/common/slicing.cpp
Normal file
40
mlx/backend/common/slicing.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
const std::vector<int>& start_indices,
|
||||
const std::vector<int>& strides) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
copy_needed |= strides[i] < 0;
|
||||
}
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.contiguous = (no_bsx_size == data_size);
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
21
mlx/backend/common/slicing.h
Normal file
21
mlx/backend/common/slicing.h
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
const std::vector<int>& start_indices,
|
||||
const std::vector<int>& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
size_t data_size,
|
||||
array& out);
|
||||
|
||||
} // namespace mlx::core
|
@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
size_t axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
@@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
auto in_remaining_strides = in.strides();
|
||||
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
auto out_remaining_shape = out.shape();
|
||||
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
|
||||
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
|
||||
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + in_loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
|
||||
|
||||
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * axis_stride];
|
||||
auto v2 = data_ptr[b * axis_stride];
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
|
@@ -12,6 +12,7 @@ namespace {
|
||||
// TODO: Add support for more combinations of input types.
|
||||
enum class TernaryOpType {
|
||||
ScalarScalarScalar,
|
||||
VectorVectorVector,
|
||||
General,
|
||||
};
|
||||
|
||||
@@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
||||
TernaryOpType topt;
|
||||
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
|
||||
topt = TernaryOpType::ScalarScalarScalar;
|
||||
} else if (
|
||||
(a.flags().row_contiguous && b.flags().row_contiguous &&
|
||||
c.flags().row_contiguous) ||
|
||||
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
||||
c.flags().col_contiguous)) {
|
||||
topt = TernaryOpType::VectorVectorVector;
|
||||
} else {
|
||||
topt = TernaryOpType::General;
|
||||
}
|
||||
@@ -33,11 +40,32 @@ void set_ternary_op_output_data(
|
||||
array& out,
|
||||
TernaryOpType topt,
|
||||
bool donate_with_move = false) {
|
||||
auto maybe_donate = [&out, donate_with_move](const array& x) {
|
||||
if (is_donatable(x, out)) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(x);
|
||||
} else {
|
||||
out.copy_shared_buffer(x);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
}
|
||||
break;
|
||||
case TernaryOpType::General:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
break;
|
||||
|
@@ -12,7 +12,7 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
|
@@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
||||
return elem_to_loc(elem, a.shape(), a.strides());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
|
||||
std::vector<stride_t> strides(shape.size(), 1);
|
||||
for (int i = shape.size() - 1; i > 0; i--) {
|
||||
strides[i - 1] = strides[i] * shape[i];
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||
// should return {{2, 4}, {{1, 2}}}.
|
||||
@@ -95,27 +104,62 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
||||
// The single array version of the above.
|
||||
inline std::tuple<std::vector<int>, std::vector<size_t>>
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::vector<int> collapsed_shape;
|
||||
std::vector<size_t> collapsed_strides;
|
||||
|
||||
if (shape.size() > 0) {
|
||||
collapsed_shape.push_back(shape[0]);
|
||||
collapsed_strides.push_back(strides[0]);
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
if (strides[i] * shape[i] != collapsed_strides.back() ||
|
||||
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
|
||||
std::numeric_limits<int>::max()) {
|
||||
collapsed_shape.push_back(shape[i]);
|
||||
collapsed_strides.push_back(strides[i]);
|
||||
} else {
|
||||
collapsed_shape.back() *= shape[i];
|
||||
collapsed_strides.back() = strides[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(collapsed_shape, collapsed_strides);
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
inline auto check_contiguity(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<stride_t>& strides) {
|
||||
size_t data_size = 1;
|
||||
size_t no_broadcast_data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
bool is_row_contiguous = true;
|
||||
bool is_col_contiguous = true;
|
||||
|
||||
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
|
||||
is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
f_stride *= shape[i];
|
||||
b_stride *= shape[ri];
|
||||
if (strides[i] > 0) {
|
||||
data_size *= shape[i];
|
||||
no_broadcast_data_size *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
|
||||
return std::make_tuple(
|
||||
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
|
||||
}
|
||||
|
||||
inline bool is_donatable(const array& in, const array& out) {
|
||||
constexpr size_t donation_extra = 16384;
|
||||
|
||||
return in.is_donatable() && in.itemsize() == out.itemsize() &&
|
||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
|
||||
${CMAKE_C_COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE}
|
||||
"-D${MLX_METAL_VERSION}"
|
||||
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/${SRC_FILE}.h
|
||||
${ARGN}
|
||||
@@ -52,6 +52,7 @@ make_jit_source(
|
||||
)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if (MLX_METAL_JIT)
|
||||
target_sources(
|
||||
@@ -64,6 +65,11 @@ if (MLX_METAL_JIT)
|
||||
make_jit_source(unary)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(
|
||||
fft
|
||||
kernels/fft/radix.h
|
||||
kernels/fft/readwrite.h
|
||||
)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
@@ -73,6 +79,7 @@ if (MLX_METAL_JIT)
|
||||
kernels/reduction/reduce_all.h
|
||||
kernels/reduction/reduce_col.h
|
||||
kernels/reduction/reduce_row.h
|
||||
kernels/reduction/reduce_init.h
|
||||
)
|
||||
make_jit_source(
|
||||
steel/gemm/gemm
|
||||
@@ -107,6 +114,8 @@ if (MLX_METAL_JIT)
|
||||
kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h
|
||||
)
|
||||
make_jit_source(quantized)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
@@ -123,9 +132,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
@@ -135,11 +147,13 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_PATH)
|
||||
|
@@ -241,9 +241,22 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
}
|
||||
}
|
||||
|
||||
size_t MetalAllocator::size(Buffer buffer) const {
|
||||
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
|
||||
}
|
||||
|
||||
MetalAllocator& allocator() {
|
||||
static MetalAllocator allocator_;
|
||||
return allocator_;
|
||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
|
||||
// not be called on exit and all the buffers will be leaked. This is necessary
|
||||
// because releasing buffers can take more than 30sec when the program holds a
|
||||
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
|
||||
// users when exiting.
|
||||
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
|
||||
// when applying this pattern to more places, or when introducing sanitizers
|
||||
// to MLX.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
|
||||
static MetalAllocator* allocator_ = new MetalAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
|
@@ -56,6 +56,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
size_t get_active_memory() {
|
||||
return active_memory_;
|
||||
};
|
||||
|
@@ -6,20 +6,62 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void binary_op(
|
||||
std::string get_kernel_name(
|
||||
BinaryOpType bopt,
|
||||
const std::string& op,
|
||||
const array& a,
|
||||
bool use_2d,
|
||||
int ndim) {
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << (use_2d ? "sv2" : "sv");
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << (use_2d ? "vs2" : "vs");
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << (use_2d ? "vv2" : "vv");
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << ndim;
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
return kname.str();
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
@@ -32,39 +74,12 @@ void binary_op(
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
|
||||
auto kernel =
|
||||
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@@ -108,9 +123,11 @@ void binary_op(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims = use_2d
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
@@ -120,15 +137,36 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op(
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op) {
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -139,39 +177,11 @@ void binary_op(
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a, out);
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
@@ -208,10 +218,11 @@ void binary_op(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads =
|
||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
// Launch a 1D or 2D grid of threads
|
||||
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
@@ -221,102 +232,65 @@ void binary_op(
|
||||
}
|
||||
}
|
||||
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "add");
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
binary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "arctan2");
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op) {
|
||||
auto& s = out.primitive().stream();
|
||||
binary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU_MULTI(DivMod)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op(inputs, out, "bitwise_and");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op(inputs, out, "bitwise_or");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op(inputs, out, "bitwise_xor");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op(inputs, out, "left_shift");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op(inputs, out, "right_shift");
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
||||
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "ge");
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "geq");
|
||||
}
|
||||
|
||||
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "le");
|
||||
}
|
||||
|
||||
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "land");
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lor");
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lae");
|
||||
}
|
||||
|
||||
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "max");
|
||||
}
|
||||
|
||||
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
|
||||
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "neq");
|
||||
}
|
||||
|
||||
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "pow");
|
||||
}
|
||||
|
||||
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "sub");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
33
mlx/backend/metal/binary.h
Normal file
33
mlx/backend/metal/binary.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@@ -56,12 +56,15 @@ inline void build_kernel(
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl
|
||||
<< " constant const size_t* " << xname << "_strides [[buffer("
|
||||
<< cnt++ << ")]]," << std::endl;
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os << " constant const size_t* in_strides [[buffer(" << cnt++
|
||||
<< ")]],\n";
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
@@ -110,13 +113,17 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
int nc_in_count = 0;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
os << ");" << std::endl;
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
@@ -124,17 +131,20 @@ inline void build_kernel(
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = nc_in_count * ndim;
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << xname << "_strides[0]";
|
||||
os << "index_0 * " << "in_strides[" << offset << "]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
|
||||
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
nc_in_count++;
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, " << xname
|
||||
<< "_strides, ndim)];" << std::endl;
|
||||
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
|
||||
<< nc_in_count * ndim << ", ndim)];" << std::endl;
|
||||
nc_in_count++;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,6 +306,7 @@ void Compiled::eval_gpu(
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
std::vector<size_t> in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
@@ -303,13 +314,17 @@ void Compiled::eval_gpu(
|
||||
auto& x = inputs[i];
|
||||
compute_encoder.set_input_array(x, cnt++);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
cnt++);
|
||||
in_strides.insert(
|
||||
in_strides.end(),
|
||||
strides[stride_idx].begin(),
|
||||
strides[stride_idx].end());
|
||||
stride_idx++;
|
||||
}
|
||||
}
|
||||
if (!in_strides.empty()) {
|
||||
compute_encoder->setBytes(
|
||||
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||
|
@@ -552,7 +552,7 @@ void winograd_conv_2D_gpu(
|
||||
|
||||
// Fill with zeros
|
||||
array zero_arr = array(0, in.dtype());
|
||||
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
|
||||
fill_gpu(zero_arr, in_padded, s);
|
||||
copies_w.push_back(zero_arr);
|
||||
|
||||
// Pick input slice from padded
|
||||
@@ -571,7 +571,6 @@ void winograd_conv_2D_gpu(
|
||||
|
||||
copies_w.push_back(in_padded_slice);
|
||||
copies_w.push_back(in_padded);
|
||||
copies_w.push_back(zero_arr);
|
||||
|
||||
MLXConvParams<2> conv_params_updated{
|
||||
/* const int N = */ in_padded.shape(0),
|
||||
@@ -911,7 +910,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Throw error
|
||||
else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
|
||||
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
|
||||
}
|
||||
|
||||
// Clear copies
|
||||
|
@@ -33,9 +33,6 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
@@ -57,22 +54,27 @@ void copy_gpu_inplace(
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector{strides_in_pre, strides_out_pre});
|
||||
auto& strides_in_ = strides[0];
|
||||
auto& strides_out_ = strides[1];
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << "s";
|
||||
kname << (use_2d ? "s2" : "s");
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << "v";
|
||||
kname << (use_2d ? "v2" : "v");
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "g";
|
||||
@@ -138,7 +140,8 @@ void copy_gpu_inplace(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
@@ -153,6 +156,7 @@ void copy_gpu_inplace(
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
@@ -164,9 +168,37 @@ void copy_gpu_inplace(
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" +
|
||||
type_to_name(val) + type_to_name(out);
|
||||
auto kernel = get_copy_kernel(d, kernel_name, val, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(val, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -37,4 +37,7 @@ void copy_gpu_inplace(
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
// Fill the output with the scalar val
|
||||
void fill_gpu(const array& val, array& out, const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
88
mlx/backend/metal/custom_kernel.cpp
Normal file
88
mlx/backend/metal/custom_kernel.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (init_value_) {
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
}
|
||||
}
|
||||
|
||||
auto check_input = [&copies, &s, this](const array& x) -> const array {
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (!ensure_row_contiguous_ || no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
std::vector<const array> checked_inputs;
|
||||
for (const array& in : inputs) {
|
||||
checked_inputs.push_back(check_input(in));
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = name_;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
lib = d.get_library(lib_name, metal::utils() + source_);
|
||||
}
|
||||
auto kernel = d.get_kernel(name_, lib);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
int index = 0;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
auto shape_info = shape_infos_[i];
|
||||
compute_encoder.set_input_array(in, index);
|
||||
index++;
|
||||
if (in.ndim() > 0) {
|
||||
int ndim = in.ndim();
|
||||
if (shape_info.shape) {
|
||||
set_vector_bytes(compute_encoder, in.shape(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
set_vector_bytes(compute_encoder, in.strides(), ndim, index);
|
||||
index++;
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), index);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (array out : outputs) {
|
||||
compute_encoder.set_output_array(out, index);
|
||||
index++;
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
MTL::Size group_dims = MTL::Size(tx, ty, tz);
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
if (!copies.empty()) {
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -1,8 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
@@ -14,11 +12,8 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
namespace {
|
||||
@@ -30,7 +25,9 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
constexpr auto get_metal_version() {
|
||||
#if defined METAL_3_1
|
||||
#if (MLX_METAL_VERSION >= 320)
|
||||
return MTL::LanguageVersion3_2;
|
||||
#elif (MLX_METAL_VERSION >= 310)
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
@@ -124,6 +121,49 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
}
|
||||
|
||||
CommandEncoder::~CommandEncoder() {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(
|
||||
array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
@@ -253,23 +293,13 @@ void Device::register_library(
|
||||
}
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::function<std::string(const std::string&)>& lib_path_func) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
std::string new_lib_path = lib_path_func(lib_name);
|
||||
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
}
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib;
|
||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name);
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
|
@@ -3,15 +3,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <Metal/Metal.hpp>
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
|
||||
@@ -19,6 +18,8 @@ namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
// Note, this function must be left inline in a header so that it is not
|
||||
// dynamically linked.
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
@@ -37,10 +38,7 @@ using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
};
|
||||
CommandEncoder(MTL::CommandBuffer* cbuf);
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
@@ -63,34 +61,8 @@ struct CommandEncoder {
|
||||
return enc;
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0) {
|
||||
auto r_buf =
|
||||
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
|
||||
@@ -98,10 +70,7 @@ struct CommandEncoder {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
|
||||
~CommandEncoder() {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
}
|
||||
~CommandEncoder();
|
||||
|
||||
private:
|
||||
void maybe_split();
|
||||
@@ -136,10 +105,14 @@ class Device {
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path);
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||
get_colocated_mtllib_path);
|
||||
|
||||
// Note, this should remain in the header so that it is not dynamically
|
||||
// linked
|
||||
void register_library(const std::string& lib_name) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
}
|
||||
}
|
||||
|
||||
MTL::Library* get_library(const std::string& name);
|
||||
|
||||
|
142
mlx/backend/metal/distributed.cpp
Normal file
142
mlx/backend/metal/distributed.cpp
Normal file
@@ -0,0 +1,142 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const array& in, const array& out, const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
}
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(
|
||||
group, in.data_shared_ptr() == nullptr ? out : in, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto task = [in = in, out = out, group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
void Send::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Schedule an async send on the comm stream
|
||||
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::send(group, in, dst);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
// Encode a signal event for the input but not a wait since we don't need to
|
||||
// wait on the output.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
}
|
||||
}
|
||||
|
||||
void Recv::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& out = outputs[0];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task = [out = out, group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
@@ -1,106 +1,803 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
|
||||
auto& in = inputs[0];
|
||||
#define MAX_STOCKHAM_FFT_SIZE 4096
|
||||
#define MAX_RADER_FFT_SIZE 2048
|
||||
#define MAX_BLUESTEIN_FFT_SIZE 2048
|
||||
// Threadgroup memory batching improves throughput for small n
|
||||
#define MIN_THREADGROUP_MEM_SIZE 256
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
|
||||
in.dtype() != complex64 || out.dtype() != complex64) {
|
||||
// Could also fallback to CPU implementation here.
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
|
||||
inline const std::vector<int> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
|
||||
std::vector<int> prime_factors(int n) {
|
||||
int z = 2;
|
||||
std::vector<int> factors;
|
||||
while (z * z <= n) {
|
||||
if (n % z == 0) {
|
||||
factors.push_back(z);
|
||||
n /= z;
|
||||
} else {
|
||||
z++;
|
||||
}
|
||||
}
|
||||
if (n > 1) {
|
||||
factors.push_back(n);
|
||||
}
|
||||
return factors;
|
||||
}
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
// Forward Declaration
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
// Number of steps for each radix in the Rader decomposition
|
||||
std::vector<int> rader;
|
||||
// Rader factor, 1 if no rader factors
|
||||
int rader_n = 1;
|
||||
int bluestein_n = -1;
|
||||
// Four step FFT
|
||||
bool four_step = false;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> plan(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
while (n % radix == 0) {
|
||||
plan[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
// For power's of two we have a fast, no transpose four step implementation.
|
||||
plan.four_step = true;
|
||||
// Rough heuristic for choosing faster powers of two when we can
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
size_t n = in.shape(axes_[0]);
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
// See if we can use Rader's algorithm to Stockham decompose n - 1
|
||||
auto rader_factors = prime_factors(factor - 1);
|
||||
int last_factor = -1;
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
plan.rader = plan_stockham_fft(factor - 1);
|
||||
plan.rader_n = factor;
|
||||
remaining_n /= factor;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_power_of_2(n) || n > 2048 || n < 4) {
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
|
||||
plan.stockham = plan_stockham_fft(remaining_n);
|
||||
return plan;
|
||||
}
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
|
||||
std::vector<int> steps;
|
||||
auto radices = supported_radices();
|
||||
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
|
||||
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
int radix = radices[i % radices.size()];
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2) {
|
||||
if (used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
// In all other cases use the second smallest radix.
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
|
||||
// Rader
|
||||
int mod_exp(int x, int y, int n) {
|
||||
int out = 1;
|
||||
while (y) {
|
||||
if (y & 1) {
|
||||
out = out * x % n;
|
||||
}
|
||||
y >>= 1;
|
||||
x = x * x % n;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int primitive_root(int n) {
|
||||
auto factors = prime_factors(n - 1);
|
||||
|
||||
for (int r = 2; r < n - 1; r++) {
|
||||
bool found = true;
|
||||
for (int factor : factors) {
|
||||
if (mod_exp(r, (n - 1) / factor, n) == 1) {
|
||||
found = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::tuple<array, array, array> compute_raders_constants(
|
||||
int rader_n,
|
||||
const Stream& s) {
|
||||
int proot = primitive_root(rader_n);
|
||||
// Fermat's little theorem
|
||||
int inv = mod_exp(proot, rader_n - 2, rader_n);
|
||||
std::vector<short> g_q(rader_n - 1);
|
||||
std::vector<short> g_minus_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
g_q[i] = mod_exp(proot, i, rader_n);
|
||||
g_minus_q[i] = mod_exp(inv, i, rader_n);
|
||||
}
|
||||
array g_q_arr(g_q.begin(), {rader_n - 1});
|
||||
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
|
||||
|
||||
std::vector<std::complex<float>> b_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
|
||||
b_q[i] = std::exp(std::complex<float>(0, pi_i));
|
||||
}
|
||||
|
||||
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
|
||||
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
|
||||
auto b_q_fft_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
|
||||
std::ptrdiff_t item_size = b_q_fft.itemsize();
|
||||
size_t fft_size = rader_n - 1;
|
||||
// This FFT is always small (<4096, batch 1) so save some overhead
|
||||
// and do it on the CPU
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ b_q.data(),
|
||||
/* data_out= */ b_q_fft_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
|
||||
}
|
||||
|
||||
// Bluestein
|
||||
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
|
||||
// We need to calculate the Bluestein twiddle factors
|
||||
// in double precision for the overall numerical stability
|
||||
// of Bluestein's FFT algorithm to be acceptable.
|
||||
//
|
||||
// Metal doesn't support float64, so instead we
|
||||
// manually implement the required operations on cpu.
|
||||
//
|
||||
// In numpy:
|
||||
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
|
||||
// w_q = np.fft.fft(1/w_k)
|
||||
// return w_k, w_q
|
||||
int length = 2 * n - 1;
|
||||
|
||||
std::vector<std::complex<float>> w_k_vec(n);
|
||||
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
|
||||
|
||||
for (int i = -n + 1; i < n; i++) {
|
||||
double theta = pow(i, 2) * M_PI / (double)n;
|
||||
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
|
||||
if (i >= 0) {
|
||||
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
|
||||
}
|
||||
}
|
||||
|
||||
array w_k({n}, complex64, nullptr, {});
|
||||
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
|
||||
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
|
||||
|
||||
array w_q({bluestein_n}, complex64, nullptr, {});
|
||||
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
|
||||
auto w_q_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
|
||||
|
||||
std::ptrdiff_t item_size = w_q.itemsize();
|
||||
size_t fft_size = bluestein_n;
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ w_q_vec.data(),
|
||||
/* data_out= */ w_q_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(w_k, w_q);
|
||||
}
|
||||
|
||||
void multi_upload_bluestein_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
|
||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
|
||||
if (real && !inverse) {
|
||||
// Convert float32->complex64
|
||||
copy_gpu(in, temp, CopyType::General, s);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "Conjugate", s);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_n;
|
||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
||||
|
||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/*inverse=*/false,
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
s);
|
||||
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
||||
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/* inverse= */ true,
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
slice_gpu(temp1, out, rstarts, strides, s);
|
||||
} else if (real && inverse) {
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
auto inv_n = array({1.0f / n}, {1}, float32);
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
copies.push_back(inv_n);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "Conjugate", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
|
||||
copies.push_back(inv_n);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
}
|
||||
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k_broadcast);
|
||||
copies.push_back(w_q_broadcast);
|
||||
copies.push_back(temp);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(pad_temp);
|
||||
copies.push_back(pad_temp1);
|
||||
}
|
||||
|
||||
void four_step_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
if (plan.bluestein_n == -1) {
|
||||
// Fast no transpose implementation for powers of 2.
|
||||
FourStepParams four_step_params = {
|
||||
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
}
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
if (four_step_params.required) {
|
||||
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
|
||||
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
|
||||
}
|
||||
|
||||
// Make sure that the array is contiguous and has stride 1 in the FFT dim
|
||||
std::vector<array> copies;
|
||||
auto check_input = [this, &copies, &s](const array& x) {
|
||||
auto check_input = [&axis, &copies, &s](const array& x) {
|
||||
// TODO: Pass the strides to the kernel so
|
||||
// we can avoid the copy when x is not contiguous.
|
||||
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
|
||||
x.flags().col_contiguous;
|
||||
bool no_copy = x.strides()[axis] == 1 &&
|
||||
(x.flags().row_contiguous || x.flags().col_contiguous);
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axes_[0]);
|
||||
for (int axis = 0; axis < x.ndim(); axis++) {
|
||||
if (axis == axes_[0]) {
|
||||
size_t cur_stride = x.shape(axis);
|
||||
for (int a = 0; a < x.ndim(); a++) {
|
||||
if (a == axis) {
|
||||
strides.push_back(1);
|
||||
} else {
|
||||
strides.push_back(cur_stride);
|
||||
cur_stride *= x.shape(axis);
|
||||
cur_stride *= x.shape(a);
|
||||
}
|
||||
}
|
||||
|
||||
auto flags = x.flags();
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
|
||||
f_stride *= x.shape(i);
|
||||
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
|
||||
b_stride *= x.shape(ri);
|
||||
}
|
||||
// This is probably over-conservative
|
||||
flags.contiguous = false;
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(x.shape(), strides);
|
||||
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.contiguous = data_size == x_copy.size();
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
|
||||
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(inputs[0]);
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
// real to complex: n -> (n/2)+1
|
||||
// complex to real: (n/2)+1 -> n
|
||||
auto out_strides = in_contiguous.strides();
|
||||
size_t out_data_size = in_contiguous.data_size();
|
||||
if (in.shape(axis) != out.shape(axis)) {
|
||||
for (int i = 0; i < out_strides.size(); i++) {
|
||||
if (out_strides[i] != 1) {
|
||||
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
}
|
||||
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: allow donation here
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
in_contiguous.data_size(),
|
||||
in_contiguous.strides(),
|
||||
in_contiguous.flags());
|
||||
if (!inplace) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
out_data_size,
|
||||
out_strides,
|
||||
in_contiguous.flags());
|
||||
}
|
||||
|
||||
// We use n / 4 threads by default since radix-4
|
||||
// is the largest single threaded radix butterfly
|
||||
// we currently implement.
|
||||
size_t m = n / 4;
|
||||
size_t batch = in.size() / in.shape(axes_[0]);
|
||||
auto radices = supported_radices();
|
||||
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
|
||||
|
||||
// Setup function constants
|
||||
bool power_of_2 = is_power_of_2(fft_size);
|
||||
|
||||
auto make_int = [](int* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
|
||||
};
|
||||
auto make_bool = [](bool* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
|
||||
};
|
||||
|
||||
std::vector<MTLFC> func_consts = {
|
||||
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
|
||||
|
||||
// Start of radix/rader step constants
|
||||
int index = 4;
|
||||
for (int i = 0; i < plan.stockham.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.stockham[i], index));
|
||||
index += 1;
|
||||
}
|
||||
for (int i = 0; i < plan.rader.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.rader[i], index));
|
||||
index += 1;
|
||||
}
|
||||
int elems_per_thread = compute_elems_per_thread(plan);
|
||||
func_consts.push_back(make_int(&elems_per_thread, 2));
|
||||
|
||||
int rader_m = n / plan.rader_n;
|
||||
func_consts.push_back(make_int(&rader_m, 3));
|
||||
|
||||
// The overall number of FFTs we're going to compute for this input
|
||||
int size = out.dtype() == float32 ? out.size() : in.size();
|
||||
if (real && inverse && four_step_params.required) {
|
||||
size = out.size();
|
||||
}
|
||||
int total_batch_size = size / n;
|
||||
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
|
||||
|
||||
// We batch among threadgroups for improved efficiency when n is small
|
||||
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
|
||||
if (four_step_params.required) {
|
||||
// Require a threadgroup batch size of at least 4 for four step FFT
|
||||
// so we can coalesce the memory accesses.
|
||||
threadgroup_batch_size =
|
||||
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
|
||||
}
|
||||
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
|
||||
// FFTs up to 2^20 are currently supported
|
||||
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
|
||||
|
||||
// ceil divide
|
||||
int batch_size =
|
||||
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
|
||||
|
||||
if (real && !four_step_params.required) {
|
||||
// We can perform 2 RFFTs at once so the batch size is halved.
|
||||
batch_size = (batch_size + 2 - 1) / 2;
|
||||
}
|
||||
int out_buffer_size = out.size();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
||||
// Only required by four step
|
||||
int step = -1;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "fft_" << n;
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
std::string inv_string = inverse ? "true" : "false";
|
||||
std::string real_string = real ? "true" : "false";
|
||||
std::string func_name;
|
||||
if (plan.bluestein_n > 0) {
|
||||
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
|
||||
<< in_type_str << "_" << out_type_str;
|
||||
func_name = "bluestein_fft";
|
||||
} else if (plan.rader_n > 1) {
|
||||
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str;
|
||||
func_name = "rader_fft";
|
||||
} else if (four_step_params.required) {
|
||||
step = four_step_params.first_step ? 0 : 1;
|
||||
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str << "_" << step << "_" << real_string;
|
||||
func_name = "four_step_fft";
|
||||
} else {
|
||||
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
|
||||
<< out_type_str;
|
||||
func_name = "fft";
|
||||
}
|
||||
std::string base_name = kname.str();
|
||||
// We use a specialized kernel for each FFT size
|
||||
kname << "_n" << fft_size << "_inv_" << inverse;
|
||||
std::string hash_name = kname.str();
|
||||
auto template_def = func_name == "four_step_fft" ? get_template_definition(
|
||||
base_name,
|
||||
func_name,
|
||||
threadgroup_mem_size,
|
||||
in_type_str,
|
||||
out_type_str,
|
||||
step,
|
||||
real)
|
||||
: get_template_definition(
|
||||
base_name,
|
||||
func_name,
|
||||
threadgroup_mem_size,
|
||||
in_type_str,
|
||||
out_type_str);
|
||||
auto kernel =
|
||||
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto group_dims = MTL::Size(1, m, 1);
|
||||
auto grid_dims = MTL::Size(batch, m, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
if (plan.bluestein_n > 0) {
|
||||
// Precomputed twiddle factors for Bluestein's
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k);
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
copies.push_back(g_q);
|
||||
copies.push_back(g_minus_q);
|
||||
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
|
||||
}
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
203
mlx/backend/metal/hadamard.cpp
Normal file
203
mlx/backend/metal/hadamard.cpp
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
// using the hadamard matrices above
|
||||
//
|
||||
// e.g. m = 2
|
||||
// METAL_FUNC void hadamard_m(thread float *x) {
|
||||
// float tmp[2];
|
||||
// tmp[0] = + x[0] + x[1];
|
||||
// tmp[1] = + x[0] - x[1];
|
||||
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
|
||||
// }
|
||||
//
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
|
||||
std::ostringstream source;
|
||||
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
|
||||
if (m == 1) {
|
||||
source << "}" << std::endl;
|
||||
return source.str();
|
||||
}
|
||||
source << " float tmp[" << m << "];" << std::endl;
|
||||
auto start = 1;
|
||||
auto end = matrix.find('\n', start);
|
||||
|
||||
int index = 0;
|
||||
while (end != std::string_view::npos) {
|
||||
source << " tmp[" << index << "] = ";
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
source << " " << row[i] << " x[" << i << "]";
|
||||
}
|
||||
source << ";" << std::endl;
|
||||
start = end + 1;
|
||||
end = matrix.find('\n', start);
|
||||
index++;
|
||||
}
|
||||
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
|
||||
<< std::endl;
|
||||
source << "}" << std::endl;
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void launch_hadamard(
|
||||
const array& in,
|
||||
array& out,
|
||||
int batch_size,
|
||||
int threads_per,
|
||||
const std::string kernel_name,
|
||||
float scale,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
const auto& lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.move_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto [n, m] = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
temp,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
1.0,
|
||||
s);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(
|
||||
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
|
||||
} else {
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
out,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
scale_,
|
||||
s);
|
||||
}
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -95,11 +95,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
slice_size *= s;
|
||||
}
|
||||
|
||||
// Launch 2D grid of threads: indices x slice
|
||||
size_t dim0 = out.size() / slice_size;
|
||||
size_t dim1 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
|
||||
// Launch 3D grid of threads
|
||||
// First two dimensions for the indices, the last one for the slice
|
||||
size_t dim0 = 1;
|
||||
size_t dim1 = 1;
|
||||
if (nidx) {
|
||||
if (inputs[1].ndim() >= 1) {
|
||||
dim0 = inputs[1].shape(0);
|
||||
}
|
||||
if (inputs[1].ndim() >= 2) {
|
||||
dim1 = inputs[1].size() / dim0;
|
||||
}
|
||||
}
|
||||
size_t dim2 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
@@ -293,7 +303,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
|
@@ -1,87 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
@@ -1,98 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_two_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
25
mlx/backend/metal/jit/gemv_masked.h
Normal file
25
mlx/backend/metal/jit/gemv_masked.h
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gemv_masked_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
|
||||
const device {itype}* mat [[buffer(0)]],
|
||||
const device {itype}* in_vec [[buffer(1)]],
|
||||
device {itype}* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device {outm_t}* out_mask [[buffer(20)]],
|
||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
@@ -17,6 +17,9 @@ const char* unary();
|
||||
const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* hadamard();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* softmax();
|
||||
@@ -30,5 +33,6 @@ const char* steel_gemm_splitk();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
const char* gemv_masked();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user