Compare commits

..

23 Commits

Author SHA1 Message Date
Angelos Katharopoulos
f5ece1d98d Tmp 2024-01-22 18:38:11 -08:00
Angelos Katharopoulos
1e4d3c7fb2 Add a compile option to nn.value_and_grad 2024-01-21 21:25:01 -08:00
Awni Hannun
db40990d33 clear output tree from cache when function goes out of scope 2024-01-19 20:52:30 -08:00
Awni Hannun
390e50b163 fix multi-output with compile 2024-01-19 20:52:30 -08:00
Awni Hannun
ed4d867092 simplify tests use compile now 2024-01-19 20:52:30 -08:00
Awni Hannun
1c3f82ca17 remove simplify 2024-01-19 20:52:30 -08:00
Awni Hannun
5c78c16f1c enable and disable compiler 2024-01-19 20:52:30 -08:00
Awni Hannun
ba8ce23597 simplify inputs also 2024-01-19 20:52:30 -08:00
Awni Hannun
df1f6c221b bug fix with move function and compile at exit 2024-01-19 20:52:30 -08:00
Awni Hannun
ecfb72157e more cpp tests 2024-01-19 20:52:30 -08:00
Awni Hannun
d34d10ebac simplify 2024-01-19 20:52:30 -08:00
Awni Hannun
b75ff47098 fix python globals bug, and erase 2024-01-19 20:52:30 -08:00
Awni Hannun
6189111494 fix test 2024-01-19 20:52:30 -08:00
Awni Hannun
4f50935c2c compile works with python closures 2024-01-19 20:52:30 -08:00
Awni Hannun
966b7faef4 fix segfault on python exit 2024-01-19 20:52:30 -08:00
Awni Hannun
0005cfe053 basic python tests 2024-01-19 20:52:30 -08:00
Awni Hannun
9739c72781 fix grad, more tests 2024-01-19 20:52:30 -08:00
Awni Hannun
9d70412a91 fix 2024-01-19 20:52:30 -08:00
Awni Hannun
7a753335d7 clean 2024-01-19 20:52:30 -08:00
Awni Hannun
cd1e5b25cc compile binding 2024-01-19 20:52:30 -08:00
Awni Hannun
21062680d5 basic compile scaffold works 2024-01-19 20:52:30 -08:00
Awni Hannun
264e9ad57e make a move on compile 2024-01-19 20:52:30 -08:00
Awni Hannun
c38d43153b fix tests for linux 2024-01-19 20:52:30 -08:00
224 changed files with 5023 additions and 18454 deletions

View File

@@ -1,8 +1,5 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
@@ -10,9 +7,6 @@ parameters:
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
linux_build_and_test:
@@ -32,23 +26,18 @@ jobs:
command: |
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libblas-dev
- run:
name: Install Python package
name: Build python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
python3 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
python3 -m unittest discover python/tests
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
@@ -63,188 +52,169 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
macos:
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
machine: true
resource_class: ml-explore/m-builder
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=3.9
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
name: Build python package
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
# eval "$(conda shell.bash hook)"
# conda activate runner-env
# cd examples/extensions && python -m pip install .
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
build_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
xcode_version:
macos_version:
type: string
default: "15.2.0"
build_env:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
name: Build package
command: |
source env/bin/activate
DEV_RELEASE=1 \
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: dist/
build_linux_test_release:
build_dev_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
extra_env:
macos_version:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
default: "14"
steps:
- checkout
- run:
name: Build wheel
name: Install dependencies
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
<< parameters.extra_env >> \
pip install twine
- run:
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: wheelhouse/
path: dist/
build_package:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
type: string
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
- run:
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
- store_artifacts:
path: dist/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test
- linux_build_and_test
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test
- build_release:
filters:
tags:
@@ -254,53 +224,20 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
- linux_build_and_test:
requires: [ hold ]
macos_version: ["13", "14"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
when: << pipeline.parameters.nightly_build >>
jobs:
- build_release:
- build_package:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
macos_version: ["13", "14"]
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
when: << pipeline.parameters.weekly_build >>
jobs:
- build_release:
- build_dev_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]
macos_version: ["13", "14"]

View File

@@ -5,11 +5,11 @@ repos:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.2.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 5.12.0
hooks:
- id: isort
args:

View File

@@ -10,9 +10,8 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -253,4 +252,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.5.0)
set(MLX_VERSION 0.0.10)
endif()
# --------------------- Processor tests -------------------------
@@ -31,13 +31,13 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
@@ -67,13 +67,15 @@ if (MLX_BUILD_METAL AND NOT METAL_LIB)
set(MLX_BUILD_METAL OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
@@ -121,27 +123,16 @@ else()
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>

View File

@@ -1,4 +1,3 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -6,17 +6,15 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
MLX also has a fully featured C++ API, which closely mirrors the Python API.
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
@@ -70,18 +68,10 @@ in the documentation.
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
**With `pip`**:
```
pip install mlx
```
**With `conda`**:
```
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.

View File

@@ -73,7 +73,6 @@ void time_unary_ops() {
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@@ -85,9 +84,7 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@@ -96,9 +93,7 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
@@ -106,7 +101,6 @@ void time_binary_ops() {
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {

View File

@@ -72,9 +72,6 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -87,15 +84,6 @@ quant_matmul = {
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),

View File

@@ -80,8 +80,10 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
if _filter(x)
else None
)

View File

@@ -1,109 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import math
import random
import mlx.core as mx
from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
x = mx.random.uniform(shape=(1000, 1024))
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(gelu), x, msg="fixed gelu")
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x, y):
x = x[: randint()]
for _ in range(10):
x = fun(x)
y = fun(y)
return x, y
return bench_fun
y = mx.random.uniform(shape=(1000, 1024))
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
time_fn(
gen_fun(mx.compile(gelu, shapeless=True)),
x,
y,
msg="shapeless variable gelu",
)
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)
def layernorm(x):
x = x.astype(mx.float32)
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + 1e-4)
x = x.astype(mx.float16)
return weight * x + bias
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x):
x = x[: randint()]
for _ in range(10):
x = fun(x)
return x
return bench_fun
random.seed(0)
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
random.seed(0)
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
random.seed(0)
time_fn(
gen_fun(mx.compile(layernorm, shapeless=True)),
x,
msg="shapeless variable layernorm",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Compile benchmarks.")
args = parser.parse_args()
bench_gelu()
bench_layernorm()

View File

@@ -1,129 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,53 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -0,0 +1,198 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,118 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,199 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.mps
import torch.nn as nn
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,35 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -1,96 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
mx.eval(dst)
idx = []
for idx_shape in idx_shapes:
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = []
for idx_shape in idx_shapes:
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [
(10, 64),
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(20_000_000,),
(10000, 64),
(100, 64),
(100, 10_000, 64),
(10, 100, 100, 21),
(1_000, 1_000, 10),
]
idx_shapes = [
[(1_000_000,)],
[(1_000_000,)],
[(100_000,)],
[(1_000_000,)],
[(20_000_000,)],
[(20_000_000,)],
[(1000000,)],
[(10000000,)],
[(1_000,)],
[(10_000,)],
[(1_000,), (1_000,)],
]
x_shapes = [
(1_000_000, 64),
(1_000_000, 64),
(100_000, 64),
(1_000_000,),
(20_000_000,),
(20_000_000,),
(1000000, 64),
(10000000, 64),
(1_000, 10_000, 64),
(10_000, 100, 100, 21),
(1_000, 10),
]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -44,13 +44,6 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,7 +101,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
import time
@@ -6,11 +6,7 @@ import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
@@ -24,15 +20,3 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 100
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

1
docs/.gitignore vendored
View File

@@ -1,3 +1,2 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 48 KiB

View File

@@ -0,0 +1,19 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

View File

@@ -12,7 +12,7 @@ import mlx.core as mx
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
version = ".".join(mx.__version__.split()[:-1])
release = version
# -- General configuration ---------------------------------------------------
@@ -26,7 +26,6 @@ extensions = [
python_use_unqualified_type_names = True
autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"https://docs.python.org/3": None,
@@ -49,12 +48,10 @@ html_theme_options = {
"repository_url": "https://github.com/ml-explore/mlx",
"use_repository_button": True,
"navigation_with_keys": False,
"logo": {
"image_light": "_static/mlx_logo.png",
"image_dark": "_static/mlx_logo_dark.png",
},
}
html_logo = "_static/mlx_logo.png"
# -- Options for HTMLHelp output ---------------------------------------------

View File

@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how to add
our assumptions on to you, let's also assume that you want to learn how add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple!
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
.. code-block:: C++
@@ -927,18 +927,18 @@ Results:
We see some modest improvements right away!
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <code>`_.
The full example code is available in `mlx-examples <code>`_.
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. code: `TODO_LINK/extensions`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc

View File

@@ -41,7 +41,6 @@ are the CPU and GPU.
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/using_streams

View File

@@ -9,10 +9,9 @@ Devices and Streams
:toctree: _autosummary
Device
Stream
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
stream

View File

@@ -9,4 +9,3 @@ Linear Algebra
:toctree: _autosummary
norm
qr

View File

@@ -180,4 +180,3 @@ In detail:
nn/layers
nn/functions
nn/losses
nn/init

View File

@@ -12,24 +12,12 @@ simple functions.
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
gelu
gelu_approx
gelu_fast_approx
glu
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu6
selu
sigmoid
silu
softmax
softplus
softshrink
step
tanh

View File

@@ -1,45 +0,0 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@@ -10,8 +10,6 @@ Layers
:template: nn-module-template.rst
ALiBi
AvgPool1d
AvgPool2d
BatchNorm
Conv1d
Conv2d
@@ -24,8 +22,6 @@ Layers
InstanceNorm
LayerNorm
Linear
MaxPool1d
MaxPool2d
Mish
MultiHeadAttention
PReLU
@@ -37,7 +33,5 @@ Layers
Sequential
SiLU
SinusoidalPositionalEncoding
Softshrink
Step
Transformer
Upsample

View File

@@ -18,7 +18,6 @@ Loss Functions
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss

View File

@@ -11,7 +11,6 @@ Module
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods

View File

@@ -25,9 +25,6 @@ Operations
argpartition
argsort
array_equal
atleast_1d
atleast_2d
atleast_3d
broadcast_to
ceil
clip
@@ -35,12 +32,9 @@ Operations
convolve
conv1d
conv2d
conv_general
cos
cosh
dequantize
diag
diagonal
divide
divmod
equal

View File

@@ -29,8 +29,19 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. toctree::
.. currentmodule:: mlx.optimizers
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -1,20 +0,0 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -1,23 +0,0 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@@ -1,15 +0,0 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
cosine_decay
exponential_decay
join_schedules
linear_schedule
step_decay

View File

@@ -9,11 +9,9 @@ Transforms
:toctree: _autosummary
eval
compile
disable_compile
enable_compile
grad
value_and_grad
jvp
vjp
vmap
simplify

View File

@@ -1,430 +0,0 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@@ -5,12 +5,9 @@ Function Transforms
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations check-out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.
Here is a simple example:
@@ -39,10 +36,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
depth. To see the complete list of function transformations check-out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiaion <auto diff>` and
:ref:`automatic vectorization <vmap>`.
Automatic Differentiation
-------------------------

View File

@@ -20,7 +20,7 @@ Transforming Compute Graphs
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
:func:`vmap` and graph optimizations like :func:`simplify`.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to

View File

@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.27)
cmake_minimum_required(VERSION 3.24)
project(mlx_sample_extensions LANGUAGES CXX)
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
endif()

View File

@@ -3,10 +3,9 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
@@ -21,7 +20,7 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
target_sources(

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <functional>
@@ -82,13 +82,6 @@ array::array(std::initializer_list<float> data)
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
@@ -104,13 +97,11 @@ void array::detach() {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
@@ -164,14 +155,6 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(array other) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
@@ -187,11 +170,9 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::ArrayDesc(
@@ -204,11 +185,9 @@ array::ArrayDesc::ArrayDesc(
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@@ -41,9 +41,6 @@ class array {
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T>
array(
std::initializer_list<T> data,
@@ -124,9 +121,6 @@ class array {
template <typename T>
T item();
template <typename T>
T item() const;
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t;
@@ -246,11 +240,6 @@ class array {
return array_desc_->inputs;
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
@@ -273,11 +262,6 @@ class array {
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
void detach();
@@ -298,12 +282,6 @@ class array {
return array_desc_->data->buffer;
};
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
@@ -344,8 +322,6 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
@@ -388,9 +364,6 @@ class array {
// The arrays position in the output list
uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
@@ -409,7 +382,7 @@ class array {
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
// and a the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
@@ -460,18 +433,6 @@ T array::item() {
return *data<T>();
}
template <typename T>
T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (!is_evaled()) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
return *data<T>();
}
template <typename It>
void array::init(It src) {
set_data(allocator::malloc(size() * size_of(dtype())));

View File

@@ -46,14 +46,6 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -97,14 +89,6 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
@@ -217,4 +201,4 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
@@ -35,9 +35,6 @@ DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@@ -53,18 +50,13 @@ DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
@@ -72,18 +64,32 @@ DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
eval(inputs, out);
unary(in, out, AbsOp());
}
}
@@ -131,8 +137,12 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -143,8 +153,12 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -155,8 +169,12 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -167,8 +185,12 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -179,8 +201,12 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -191,8 +217,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -204,23 +234,30 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
@@ -232,8 +269,12 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -244,8 +285,12 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -290,12 +335,55 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
@@ -322,8 +410,12 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) {
case Base::e:
vvlogf(
@@ -347,8 +439,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
@@ -360,6 +456,47 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -389,8 +526,13 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return -x; });
}
@@ -403,13 +545,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
@@ -451,8 +587,12 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -463,8 +603,12 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -475,8 +619,12 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
@@ -487,8 +635,12 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
@@ -543,8 +695,12 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -555,8 +711,12 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);

View File

@@ -24,6 +24,8 @@ void _qmm_t_4_64(
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;

View File

@@ -274,12 +274,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -1,42 +1,8 @@
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${COMPILER}
${PROJECT_SOURCE_DIR}
${CLANG}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
@@ -44,14 +10,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

View File

@@ -7,7 +7,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -74,7 +73,7 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void DivMod::eval(
@@ -136,56 +135,88 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
binary(a, b, out, RemainderFn{});
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
return x == y || (std::isnan(x) && std::isnan(y));
});
} else {
comparison_op(inputs[0], inputs[1], out, detail::Equal());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
}
}
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
}
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
}
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
}
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
}
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto op = [](auto x, auto y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = (x > y) ? x : y;
auto minval = (x > y) ? y : x;
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(std::exp(minval - maxval)));
};
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
binary_op<float>(a, b, out, op);
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
binary_op<float16_t>(a, b, out, op);
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
binary_op<bfloat16_t>(a, b, out, op);
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
@@ -202,40 +233,65 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
}
struct PowerFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
if (exp < 0) {
throw std::invalid_argument(
"Integers cannot be raise to negative powers");
}
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
binary(a, b, out, PowerFn{});
}
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
@@ -9,7 +10,7 @@ namespace mlx::core {
namespace {
enum class BinaryOpType {
enum BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
@@ -20,17 +21,17 @@ enum class BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = BinaryOpType::ScalarScalar;
bopt = ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
} else {
bopt = BinaryOpType::General;
bopt = General;
}
return bopt;
}
@@ -39,83 +40,29 @@ void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
BinaryOpType bopt) {
switch (bopt) {
case BinaryOpType::ScalarScalar:
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
case ScalarVector:
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
break;
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
case VectorScalar:
case VectorVector:
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
break;
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
case General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
@@ -424,25 +371,25 @@ void binary_op(
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
if (bopt == ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
if (bopt == ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
if (bopt == VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
if (bopt == VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
@@ -475,17 +422,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
dim = d;
}
@@ -495,20 +442,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:

View File

@@ -260,14 +260,14 @@ void binary_op(
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
if (bopt == ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
if (bopt == ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
if (bopt == VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
if (bopt == VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
dim = d;
}
@@ -347,20 +347,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:

View File

@@ -1,507 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <filesystem>
#include <list>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
std::ostringstream os;
std::ostringstream constant_hasher;
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
a.primitive().print(os);
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
void print_constant(std::ostream& os, const array& x) {
switch (x.dtype()) {
case float32:
return print_float_constant<float>(os, x);
case float16:
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
return print_int_constant<int8_t>(os, x);
case int16:
return print_int_constant<int16_t>(os, x);
case int32:
return print_int_constant<int32_t>(os, x);
case int64:
return print_int_constant<int64_t>(os, x);
case uint8:
return print_int_constant<uint8_t>(os, x);
case uint16:
return print_int_constant<uint16_t>(os, x);
case uint32:
return print_int_constant<uint32_t>(os, x);
case uint64:
return print_int_constant<uint64_t>(os, x);
case bool_:
os << std::boolalpha << x.item<bool>();
return;
default:
throw std::runtime_error("Unsupported constant type");
}
}
std::string get_type_string(Dtype d) {
switch (d) {
case float32:
return "float";
case float16:
return "float16_t";
case bfloat16:
return "bfloat16_t";
case complex64:
return "complex64_t";
case bool_:
return "bool";
case int8:
return "int8_t";
case int16:
return "int16_t";
case int32:
return "int32_t";
case int64:
return "int64_t";
case uint8:
return "uint8_t";
case uint16:
return "uint16_t";
case uint32:
return "uint32_t";
case uint64:
return "uint64_t";
default: {
std::ostringstream msg;
msg << "Unsupported compilation type " << d;
throw std::runtime_error(msg.str());
}
}
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
lib_exists = f.good();
}
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
<< " with error code " << return_code << "." << std::endl;
throw std::runtime_error(msg.str());
}
}
// load library
libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
<< kernel_name << std::endl
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
return fun;
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
if (contiguous) {
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
} else {
for (int d = 0; d < ndim; ++d) {
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
<< "]; ++i" << d << ") {" << std::endl;
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[i];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
<< xname << ";" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
if (contiguous) {
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
} else {
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
}
// Close loops
if (contiguous) {
os << " }" << std::endl;
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
os << " " << xname << " += " << xname << "_strides[" << d << "];"
<< std::endl;
if (d < ndim - 1) {
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
<< " * shape[" << d + 1 << "];" << std::endl;
}
}
os << " }" << std::endl;
}
}
// Finish the kernel
os << "}" << std::endl;
}
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = true;
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
kernel_name,
inputs_,
outputs_,
tape_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
// Allocate space for the outputs possibly with input donation
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
for (auto& x : outputs) {
args.push_back(x.data<void>());
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
}
} // namespace mlx::core

View File

@@ -1,56 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/primitives.h"
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d);
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
os << x.item<T>();
}
template <typename T>
void print_complex_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
T constant = x.item<T>();
os << get_type_string(x.dtype()) << "("
<< std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< constant.real() << ", " << constant.imag() << ")"
<< std::setprecision(old_precision);
}
void print_constant(std::ostream& os, const array& x);
inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
} // namespace mlx::core

View File

@@ -1,11 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
#include "mlx/backend/common/ops.h"
// clang-format on
const char* get_kernel_preamble();

View File

@@ -1,10 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#include <vecLib/cblas_new.h>
#else
#include <cblas.h>
#endif
@@ -28,16 +27,14 @@ void slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
const T* start_wt_ptr = wt.data<T>();
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iH = in.shape(1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
@@ -64,15 +61,12 @@ void slow_conv_1D(
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
if (ih >= 0 && ih < iH) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
in_ptr[ih * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
@@ -96,16 +90,14 @@ void slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
@@ -128,8 +120,6 @@ void slow_conv_2D(
const size_t out_stride_W = out.strides()[2];
const size_t out_stride_O = out.strides()[3];
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
@@ -141,10 +131,8 @@ void slow_conv_2D(
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
@@ -165,74 +153,25 @@ void slow_conv_2D(
} // o
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
@@ -252,17 +191,13 @@ void slow_conv_2D(
};
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@@ -311,18 +246,15 @@ void dispatch_slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
if (in.dtype() == float32) {
return slow_conv_1D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == float16) {
return slow_conv_1D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == bfloat16) {
return slow_conv_1D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@@ -335,18 +267,15 @@ void dispatch_slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
if (in.dtype() == float32) {
return slow_conv_2D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == float16) {
return slow_conv_2D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == bfloat16) {
return slow_conv_2D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@@ -564,16 +493,13 @@ void conv_1D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
const std::vector<int>& wt_dilation) {
if (wt_dilation[0] == 1) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation);
}
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
}
void conv_2D_cpu(
@@ -582,11 +508,8 @@ void conv_2D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
const std::vector<int>& wt_dilation) {
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
}
} // namespace
@@ -600,26 +523,12 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
// 2D convolution
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// 1D convolution
else if (in.ndim() == (1 + 2)) {
return conv_1D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// Throw error
else {

View File

@@ -289,16 +289,11 @@ void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
dst.set_data(
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
src.data_size(),
src.strides(),
src.flags());
break;
case CopyType::Scalar:
case CopyType::General:

View File

@@ -1,13 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#include <vecLib/cblas_new.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
@@ -41,15 +39,12 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)
@@ -79,7 +74,6 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
@@ -87,7 +81,6 @@ DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
@@ -103,6 +96,7 @@ DEFAULT(Subtract)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
namespace {
@@ -132,13 +126,6 @@ inline void matmul_common_general(
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(

11
mlx/backend/common/erf.h Normal file
View File

@@ -0,0 +1,11 @@
// Copyright © 2023 Apple Inc.
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
float erfinv(float a);
} // namespace mlx::core

View File

@@ -1,34 +0,0 @@
#!/bin/bash
#
# This script generates a C++ function that provides the CPU
# code for use with kernel generation.
#
# Copyright © 2023-24 Apple Inc.
OUTPUT_FILE=$1
GCC=$2
SRCDIR=$3
CLANG=$4
if [ $CLANG = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>
EOM
fi
CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core::detail;
)preamble";
}
EOF

View File

@@ -1,602 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
namespace mlx::core::detail {
namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace
typedef union {
int i;
float f;
} IntOrFloat;
inline float fast_exp(float x) {
if (x == -std::numeric_limits<float>::infinity()) {
return 0.0f;
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
return x;
}
x *= 1.442695; // multiply with log_2(e)
float ipart, fpart;
IntOrFloat epart;
x = std::max(-80.f, std::min(x, 80.f));
ipart = std::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart.i = (int(ipart) + 127) << 23;
return epart.f * x;
}
inline float fast_erf(float a) {
float r, s, t, u;
t = std::abs(a);
s = a * a;
if (t > 0.927734375f) {
// maximum error 0.99527 ulp
r = std::fma(
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
u = std::fma(
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
r = std::fma(r, s, u);
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
r = std::fma(r, t, -t);
// TODO, replace with expm1 when implemented
r = 1.0f - std::exp(r);
r = std::copysign(r, a);
} else {
// maximum error 0.98929 ulp
r = -5.96761703e-4f; // -0x1.38e000p-11
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
r = std::fma(r, a, a);
}
return r;
}
inline float fast_erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
struct Abs {
template <typename T>
T operator()(T x) {
return std::abs(x);
};
uint8_t operator()(uint8_t x) {
return x;
};
uint16_t operator()(uint16_t x) {
return x;
};
uint32_t operator()(uint32_t x) {
return x;
};
uint64_t operator()(uint64_t x) {
return x;
};
bool operator()(bool x) {
return x;
};
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return std::acos(x);
};
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return std::acosh(x);
};
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return std::asin(x);
};
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return std::asinh(x);
};
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return std::atan(x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return std::atanh(x);
};
};
struct Ceil {
template <typename T>
T operator()(T x) {
return std::ceil(x);
};
int8_t operator()(int8_t x) {
return x;
};
int16_t operator()(int16_t x) {
return x;
};
int32_t operator()(int32_t x) {
return x;
};
int64_t operator()(int64_t x) {
return x;
};
uint8_t operator()(uint8_t x) {
return x;
};
uint16_t operator()(uint16_t x) {
return x;
};
uint32_t operator()(uint32_t x) {
return x;
};
uint64_t operator()(uint64_t x) {
return x;
};
bool operator()(bool x) {
return x;
};
};
struct Cos {
template <typename T>
T operator()(T x) {
return std::cos(x);
};
};
struct Cosh {
template <typename T>
T operator()(T x) {
return std::cosh(x);
};
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x)));
};
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
};
};
struct Exp {
template <typename T>
T operator()(T x) {
return fast_exp(x);
};
complex64_t operator()(complex64_t x) {
return std::exp(x);
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return std::floor(x);
};
int8_t operator()(int8_t x) {
return x;
};
int16_t operator()(int16_t x) {
return x;
};
int32_t operator()(int32_t x) {
return x;
};
int64_t operator()(int64_t x) {
return x;
};
uint8_t operator()(uint8_t x) {
return x;
};
uint16_t operator()(uint16_t x) {
return x;
};
uint32_t operator()(uint32_t x) {
return x;
};
uint64_t operator()(uint64_t x) {
return x;
};
bool operator()(bool x) {
return x;
};
};
struct Log {
template <typename T>
T operator()(T x) {
return std::log(x);
};
};
struct Log2 {
template <typename T>
T operator()(T x) {
return std::log2(x);
};
};
struct Log10 {
template <typename T>
T operator()(T x) {
return std::log10(x);
};
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
};
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
};
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
};
};
struct Round {
template <typename T>
T operator()(T x) {
return std::rint(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
}
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto one = static_cast<decltype(x)>(1.0);
return one / (one + fast_exp(-x));
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
}
uint8_t operator()(uint8_t x) {
return x != 0;
}
uint16_t operator()(uint16_t x) {
return x != 0;
}
uint32_t operator()(uint32_t x) {
return x != 0;
}
uint64_t operator()(uint64_t x) {
return x != 0;
}
};
struct Sin {
template <typename T>
T operator()(T x) {
return std::sin(x);
};
};
struct Sinh {
template <typename T>
T operator()(T x) {
return std::sinh(x);
};
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
};
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return std::sqrt(x);
};
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
};
};
struct Tan {
template <typename T>
T operator()(T x) {
return std::tan(x);
};
};
struct Tanh {
template <typename T>
T operator()(T x) {
return std::tanh(x);
};
};
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
template <typename T>
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (std::isnan(x) && std::isnan(y));
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct Maximum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return (x > y) ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
}
};
struct Minimum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return x < y ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return x < y ? x : y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = Maximum()(x, y);
auto minval = Minimum()(x, y);
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval)));
};
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
};
struct Power {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
} // namespace mlx::core::detail

View File

@@ -10,7 +10,7 @@
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/erf.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
@@ -26,7 +26,7 @@ void Abs::eval(const std::vector<array>& inputs, array& out) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Abs());
unary(in, out, AbsOp());
}
}
@@ -38,7 +38,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcCos());
unary_fp(in, out, [](auto x) { return std::acos(x); });
} else {
throw std::invalid_argument(
"[arccos] Cannot compute inverse cosine of elements in array"
@@ -50,7 +50,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcCosh());
unary_fp(in, out, [](auto x) { return std::acosh(x); });
} else {
throw std::invalid_argument(
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
@@ -62,7 +62,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcSin());
unary_fp(in, out, [](auto x) { return std::asin(x); });
} else {
throw std::invalid_argument(
"[arcsin] Cannot compute inverse sine of elements in array"
@@ -74,7 +74,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcSinh());
unary_fp(in, out, [](auto x) { return std::asinh(x); });
} else {
throw std::invalid_argument(
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
@@ -86,7 +86,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcTan());
unary_fp(in, out, [](auto x) { return std::atan(x); });
} else {
throw std::invalid_argument(
"[arctan] Cannot compute inverse tangent of elements in array"
@@ -98,7 +98,7 @@ void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcTanh());
unary_fp(in, out, [](auto x) { return std::atanh(x); });
} else {
throw std::invalid_argument(
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
@@ -172,7 +172,7 @@ void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, detail::Ceil());
unary_fp(in, out, [](auto x) { return std::ceil(x); });
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -212,7 +212,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Cos());
unary_fp(in, out, [](auto x) { return std::cos(x); });
} else {
throw std::invalid_argument(
"[cos] Cannot compute cosine of elements in array"
@@ -224,7 +224,7 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Cosh());
unary_fp(in, out, [](auto x) { return std::cosh(x); });
} else {
throw std::invalid_argument(
"[cosh] Cannot compute hyperbolic cosine of elements in array"
@@ -232,37 +232,25 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
});
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
});
break;
default:
throw std::invalid_argument(
@@ -276,13 +264,20 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
});
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
});
break;
default:
throw std::invalid_argument(
@@ -294,8 +289,9 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Exp());
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
@@ -307,7 +303,7 @@ void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, detail::Floor());
unary_fp(in, out, [](auto x) { return std::floor(x); });
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -335,13 +331,13 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
if (is_floating_point(out.dtype())) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
unary_fp(in, out, [](auto x) { return std::log(x); });
break;
case Base::two:
unary_fp(in, out, detail::Log2());
unary_fp(in, out, [](auto x) { return std::log2(x); });
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
unary_fp(in, out, [](auto x) { return std::log10(x); });
break;
}
} else {
@@ -355,7 +351,7 @@ void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Log1p());
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
@@ -366,27 +362,27 @@ void Log1p::eval(const std::vector<array>& inputs, array& out) {
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
unary(in, out, [](auto x) { return !x; });
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
binary(in1, in2, out, [](auto x, auto y) { return x && y; });
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
binary(in1, in2, out, [](auto x, auto y) { return x || y; });
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
unary(in, out, [](auto x) { return -x; });
}
void Pad::eval(const std::vector<array>& inputs, array& out) {
@@ -489,7 +485,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (not is_integral(in.dtype())) {
unary_fp(in, out, detail::Round());
unary_fp(in, out, RoundOp());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -500,7 +496,11 @@ void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Sigmoid());
auto sigmoid_op = [](auto x) {
auto one = static_cast<decltype(x)>(1.0);
return one / (one + std::exp(-x));
};
unary_fp(in, out, sigmoid_op);
} else {
throw std::invalid_argument(
"[sigmoid] Cannot sigmoid of elements in array with"
@@ -514,7 +514,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
unary(in, out, SignOp());
}
}
@@ -522,7 +522,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Sin());
unary_fp(in, out, [](auto x) { return std::sin(x); });
} else {
throw std::invalid_argument(
"[sin] Cannot compute sine of elements in array"
@@ -534,7 +534,7 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Sinh());
unary_fp(in, out, [](auto x) { return std::sinh(x); });
} else {
throw std::invalid_argument(
"[sinh] Cannot compute hyperbolic sine of elements in array"
@@ -643,16 +643,18 @@ void Split::eval(
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
unary(in, out, [](auto x) { return x * x; });
}
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
unary_fp(in, out, [](auto x) {
return static_cast<decltype(x)>(1.0) / sqrt(x);
});
} else {
unary_fp(in, out, detail::Sqrt());
unary_fp(in, out, [](auto x) { return sqrt(x); });
}
}
@@ -665,7 +667,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Tan());
unary_fp(in, out, [](auto x) { return std::tan(x); });
} else {
throw std::invalid_argument(
"[tan] Cannot compute tangent of elements in array"
@@ -677,7 +679,7 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Tanh());
unary_fp(in, out, [](auto x) { return std::tanh(x); });
} else {
throw std::invalid_argument(
"[tanh] Cannot compute hyperbolic tangent of elements in array"

View File

@@ -1,153 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = std::max(M, N);
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), float32, nullptr, {});
auto flags = in.flags();
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
copy_inplace(in, r, CopyType::General);
for (int i = 0; i < num_matrices; ++i) {
// Zero lower triangle
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * M + j * N + k] = 0;
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
copy_inplace(in, q, CopyType::General);
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"
@@ -118,12 +119,6 @@ void _qmm_dispatch_typed(
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
@@ -140,12 +135,6 @@ void _qmm_dispatch_typed(
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
@@ -162,12 +151,6 @@ void _qmm_dispatch_typed(
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);

View File

@@ -1,13 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@@ -1,72 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename Op>
void select_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
}
} // namespace
void Select::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
}
} // namespace mlx::core

View File

@@ -53,12 +53,7 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.strides().back() == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});

View File

@@ -1,226 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
General,
};
TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
topt = TernaryOpType::ScalarScalarScalar;
} else {
topt = TernaryOpType::General;
}
return topt;
}
void set_ternary_op_output_data(
const array& a,
const array& b,
const array& c,
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims1(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 2:
ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
int c_idx = elem_to_loc(i, c.shape(), c.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
return;
}
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
} // namespace
} // namespace mlx::core

View File

@@ -11,24 +11,68 @@ namespace mlx::core {
namespace {
void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
struct AbsOp {
template <typename T>
T operator()(T x) {
return std::abs(x);
}
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct SignOp {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
}
uint8_t operator()(uint8_t x) {
return x != 0;
}
uint16_t operator()(uint16_t x) {
return x != 0;
}
uint32_t operator()(uint32_t x) {
return x != 0;
}
uint64_t operator()(uint64_t x) {
return x != 0;
}
};
struct RoundOp {
template <typename T>
T operator()(T x) {
return std::rint(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
}
};
template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);

View File

@@ -1,28 +1,7 @@
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
)
add_custom_target(
compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx compiled_preamble)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
@@ -32,12 +11,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -1,378 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim,
bool dynamic_dims) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
bool add_indices = false;
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
// Add the input arguments
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
<< std::endl;
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
<< std::endl;
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
<< std::endl;
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
if (add_indices) {
if (!dynamic_dims) {
if (ndim == 1) {
os << " uint index_0 = pos.x;" << std::endl;
} else if (ndim == 2) {
os << " uint index_0 = pos.y;" << std::endl
<< " uint index_1 = pos.x;" << std::endl;
} else if (ndim == 3) {
os << " uint index_0 = pos.z;" << std::endl
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << xname << "_strides[0]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
}
os << "];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
// Finish the kernel
os << "}" << std::endl;
if (cnt > 31) {
std::ostringstream msg;
msg << "[compile] Too many inputs/outputs fused in the Metal Compiled "
<< "primitive which exhausted the available argument buffers for "
<< "the kernel. Please file an issue with the function that results "
<< "in this error. The name of the kernel is '" << kernel_name << "'";
throw std::runtime_error(msg.str());
}
}
void Compiled::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Make the name for the kernel library
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_);
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::get_kernel_preamble() << std::endl;
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i),
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false);
}
build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true);
lib = d.get_library(kernel_lib_, kernel.str());
}
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
!is_scalar(x)) {
contiguous = false;
break;
}
}
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
std::vector<std::vector<size_t>> initial_strides;
initial_strides.push_back(outputs[0].strides());
std::vector<int> shape;
std::vector<std::vector<size_t>> strides;
if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
// Skip scalar inputs.
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (output_shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
initial_strides.push_back(std::move(xstrides));
}
std::tie(shape, strides) =
collapse_contiguous_dims(output_shape, initial_strides);
}
// Get the kernel from the lib
int ndim = shape.size();
bool dynamic = ndim >= 8;
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
if (dynamic) {
kernel_name += "dynamic";
} else {
kernel_name += std::to_string(shape.size());
}
}
auto kernel = d.get_kernel(kernel_name, lib);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
stride_idx++;
}
}
// Allocate space for the outputs possibly with input donation
{
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].move_shared_buffer(in);
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
// Put the outputs in
for (auto& x : outputs) {
set_array_buffer(compute_encoder, x, cnt++);
}
// Put the output shape and strides in
if (!contiguous) {
compute_encoder->setBytes(
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
}
// Put the number of dims in if it is dynamic
if (dynamic) {
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
}
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].size();
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core

View File

@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* get_kernel_preamble();
}

View File

@@ -1,78 +1,87 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace mlx::steel;
namespace mlx::core {
namespace {
template <int N>
void explicit_gemm_conv_ND_gpu(
void explicit_gemm_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
// Prepare unfolding array
std::vector<int> unfolded_shape = {
static_cast<int>(out.size() / conv_params.O),
static_cast<int>(wt.size() / conv_params.O)};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
const MLXConvParams<1>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
// Perform gemm
std::vector<array> copies;
std::vector<array> copies = {in_padded, in_strided};
return steel_matmul(
s,
d,
/*a = */ in_unfolded,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ unfolded_shape[0],
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ unfolded_shape[1],
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ unfolded_shape[1],
/*b_cols = */ unfolded_shape[1],
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
@@ -86,9 +95,7 @@ void conv_1D_gpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
@@ -99,19 +106,24 @@ void conv_1D_gpu(
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const int dil[NDIM] = */ {wt_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ 1,
/* const bool flip = */ flip};
};
// Direct to explicit gemm conv
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
if (wt_dilation[0] == 1) {
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
}
}
void slow_conv_2D_gpu(
@@ -157,262 +169,114 @@ void implicit_gemm_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
// Determine block and warp tiles
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
int bk = 16;
if (implicit_N <= 16) {
bn = 8;
wm = 4;
wn = 1;
}
int tn = (implicit_N + bn - 1) / bn;
int tm = (implicit_M + bm - 1) / bm;
int swizzle_log = 0;
// Fix small channel specialization
int n_channel_specialization = 0;
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
if (conv_params.C <= 2) {
gemm_k_iters = (implicit_K + bk - 1) / bk;
n_channel_specialization = conv_params.C;
} else if (conv_params.C <= 4) {
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
n_channel_specialization = conv_params.C;
}
bool small_filter = (!n_channel_specialization) &&
(conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
int inp_jump_w = sign * ijw;
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
sign * (conv_params.wS[1] - 1) * ijw;
// Build implicit gemm params
ImplicitGemmConv2DParams gemm_params{
/* const int M = */ implicit_M,
/* const int N = */ implicit_N,
/* const int K = */ implicit_K,
/* const int gemm_k_iterations = */ gemm_k_iters,
/* const int inp_jump_w = */ inp_jump_w,
/* const int inp_jump_h = */ inp_jump_h,
/* const int inp_jump_c = */ inp_jump_c,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
size_t grid_dim_y = (tm + tile - 1) / tile;
size_t grid_dim_x = tn * tile;
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
void explicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
// Pad input
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Determine block and warp tiles
int wm = 2, wn = 2;
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Make jump params
int f_wgt_jump_h =
std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
int f_wgt_jump_w =
std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
int f_out_jump_h =
std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
int f_out_jump_w =
std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
int adj_out_hw = adj_out_h * adj_out_w;
int adj_implicit_m = conv_params.N * adj_out_hw;
// Make strided view
std::vector<int> strided_shape = {
conv_params.N,
conv_params.oS[0],
conv_params.oS[1],
conv_params.wS[0],
conv_params.wS[1],
conv_params.C};
Conv2DGeneralJumpParams jump_params{
/* const int f_wgt_jump_h = */ f_wgt_jump_h,
/* const int f_wgt_jump_w = */ f_wgt_jump_w,
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[2] * conv_params.str[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
/* const int f_out_jump_h = */ f_out_jump_h,
/* const int f_out_jump_w = */ f_out_jump_w,
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
/* const int adj_out_h = */ adj_out_h,
/* const int adj_out_w = */ adj_out_w,
/* const int adj_out_hw = */ adj_out_hw,
/* const int adj_implicit_m = */ adj_implicit_m};
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
// Make base info
std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);
int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];
int init_h =
(conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
int init_w =
(conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;
int wh_base = 0;
while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
int wh_size =
((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
base_h[i] = {wh_base, wh_size};
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;
int ww_base = 0;
while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
int ww_size =
((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
base_w[j] = {ww_base, ww_size};
}
// Collect block sizes
int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
int bk = 16;
int tn = (implicit_N + bn - 1) / bn;
int tm = (adj_implicit_m + bm - 1) / bm;
int swizzle_log = 0;
// Get channel iteration info
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = channel_k_iters;
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
int inp_jump_w = sign * ijw;
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
sign * (conv_params.wS[1] - 1) * ijw;
// Build implicit gemm params
ImplicitGemmConv2DParams gemm_params{
/* const int M = */ implicit_M,
/* const int N = */ implicit_N,
/* const int K = */ implicit_K,
/* const int gemm_k_iterations = */ gemm_k_iters,
/* const int inp_jump_w = */ inp_jump_w,
/* const int inp_jump_h = */ inp_jump_h,
/* const int inp_jump_c = */ inp_jump_c,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
size_t grid_dim_y = (tm + tile - 1) / tile;
size_t grid_dim_x = tn * tile;
size_t grid_dim_z = f_out_jump_h * f_out_jump_w;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
return steel_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
}
void winograd_conv_2D_gpu(
@@ -437,7 +301,6 @@ void winograd_conv_2D_gpu(
// Fill with zeros
array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
@@ -466,8 +329,7 @@ void winograd_conv_2D_gpu(
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
/* const int idil[NDIM] = */ {1, 1},
/* const int dil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
@@ -477,8 +339,6 @@ void winograd_conv_2D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ false,
};
int O_c = conv_params.O;
@@ -602,8 +462,6 @@ void conv_2D_gpu(
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
@@ -615,47 +473,37 @@ void conv_2D_gpu(
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
// Direct to winograd conv
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
(channels_large || (channels_med && inp_large))) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
conv_params.dil[1] == 1) {
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
}
@@ -686,31 +534,11 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
// 2D conv
if (out.ndim() == 4) {
conv_2D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
}
// 1D conv
else if (out.ndim() == 3) {
conv_1D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// Throw error
else {

View File

@@ -12,15 +12,11 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
@@ -71,8 +67,7 @@ void copy_gpu_inplace(
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
@@ -26,8 +26,7 @@ static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
?: MTL::CreateSystemDefaultDevice();
auto device = static_cast<MTL::Device*>(devices->object(0));
if (!device) {
throw std::runtime_error("Failed to load device");
}
@@ -215,6 +214,15 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
return eit->second;
}
MTL::ArgumentEncoder* Device::argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
// NB array here is already autoreleased but the returned argument
// encoder is owned by the caller and must be released/autoreleased
NS::Array* arg_desc_arr = NS::Array::array(
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
return device_->newArgumentEncoder(arg_desc_arr);
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
@@ -234,127 +242,37 @@ void Device::register_library(
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Function* Device::get_function_(
const std::string& name,
MTL::Library* mtl_lib) {
// Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name);
return mtl_function;
}
MTL::Function* Device::get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib) {
if (func_consts.empty() && (specialized_name == name)) {
return get_function_(name, mtl_lib);
}
// Prepare function constants
auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
for (auto [value, type, index] : func_consts) {
mtl_func_consts->setConstantValue(value, type, index);
}
// Prepare function desc
auto desc = MTL::FunctionDescriptor::functionDescriptor();
desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
desc->setSpecializedName(
NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
desc->setConstantValues(mtl_func_consts);
// Pull kernel from library
NS::Error* error = nullptr;
auto mtl_function = mtl_lib->newFunction(desc, &error);
// Throw error if unable to build metal function
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load function " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function) {
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel;
if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
}
// Throw error if unable to compile metal function
if (!mtl_function || !kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
@@ -364,175 +282,11 @@ MTL::ComputePipelineState* Device::get_kernel_(
throw std::runtime_error(msg.str());
}
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions) {
// Check inputs
if (!linked_functions) {
return get_kernel_(name, mtl_function);
}
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
throw std::runtime_error(msg.str());
}
// Prepare compute pipeline state descriptor
auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
desc->setComputeFunction(mtl_function);
desc->setLinkedFunctions(linked_functions);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
auto kernel = device_->newComputePipelineState(
desc, MTL::PipelineOptionNone, nullptr, &error);
// Throw error if unable to compile metal function
if (!kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return kernel;
}
MTL::Library* Device::get_library(const std::string& name) {
auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(source);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
return nullptr;
}
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
std::vector<NS::Object*> objs(funcs.size());
for (int i = 0; i < funcs.size(); i++) {
objs[i] = funcs[i];
}
NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
lfuncs->setPrivateFunctions(funcs_arr);
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
kernel_map_.insert({name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
}
Device& device(mlx::core::Device) {
static Device metal_device;
return metal_device;

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -31,9 +31,6 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
class Device {
public:
Device();
@@ -62,73 +59,14 @@ class Device {
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::Library* get_library(const std::string& name);
MTL::Library* get_library(
const std::string& name,
const std::string& source_string,
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
const std::string& name,
const std::string& lib_name = "mlx");
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
MTL::Function* get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib);
MTL::LinkedFunctions* get_linked_functions_(
const std::vector<MTL::Function*>& funcs);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
@@ -39,75 +39,114 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
}
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
for (auto s : slice_sizes_) {
slice_size *= s;
}
// Launch 2D grid of threads: indices x slice
size_t dim0 = out.size() / slice_size;
size_t dim1 = slice_size;
auto group_dims = get_block_dims(dim0, dim1, 1);
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
size_t ndim = src.ndim();
size_t nthreads = out.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy(
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
if (idx_ndim > 0) {
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
MTL::ResourceUsageRead);
}
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
// Set all the buffers
set_array_buffer(compute_encoder, src, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -142,28 +181,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
// Bail from fast path (1d index specialization) if scatter dims aren't
// the outermost dims and contiguous since update access won't be raster
// order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
}
// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "scatter" << type_to_name(out) << idx_type_name;
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
@@ -188,109 +206,126 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& upd = inputs.back();
size_t nthreads = upd.size();
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
// Set all the buffers
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
// Make the argument buffer to store the indices for the
// `Indices` struct in kernels/indexing.metal
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[0]->setIndex(0);
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[0]->setArrayLength(nidx);
// Set update info
uint upd_ndim = upd.ndim();
// Shapes
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[1]->setIndex(nidx + 1);
// Strides
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
arg_descs[2]->setIndex(nidx + 2);
// Indices ndim
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
arg_descs[3]->setIndex(nidx + 3);
// Get the argument encoder
auto arg_enc = d.argument_encoder(arg_descs);
// Allocate and fill buffers for shapes and strides
int idx_ndim = nidx ? inputs[1].ndim() : 0;
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy(
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end(),
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
std::copy(
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end(),
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
}
// Allocate the argument buffer
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
// Register data with the encoder
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
for (int i = 0; i < nidx; ++i) {
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
}
if (idx_ndim > 0) {
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
MTL::ResourceUsageRead);
arg_enc->setBuffer(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
compute_encoder->useResource(
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
MTL::ResourceUsageRead);
}
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Cleanup temporaries
arg_enc->release();
d.get_command_buffer(s.index)->addCompletedHandler(
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
allocator::free(arg_buf);
allocator::free(idx_shapes_buf);
allocator::free(idx_strides_buf);
});
}
} // namespace mlx::core

View File

@@ -3,13 +3,10 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@@ -25,14 +22,11 @@ set(
"quantized"
"random"
"reduce"
"rope"
"scan"
"softmax"
"sort"
"ternary"
"unary"
"gather"
"scatter"
"indexing"
)
function(build_kernel_base TARGET SRCFILE DEPS)
@@ -51,7 +45,11 @@ endfunction(build_kernel_base)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
endif()
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
endfunction(build_kernel)
foreach(KERNEL ${KERNELS})

View File

@@ -11,6 +11,8 @@ template <typename U>
struct IndexValPair {
uint32_t index;
U val;
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};
template <typename U>
@@ -61,12 +63,24 @@ struct ArgMax {
}
};
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>{
return IndexValPair<U>(
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
};
);
}
@@ -80,6 +94,7 @@ template <typename T, typename Op, int N_READS>
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -108,9 +123,7 @@ template <typename T, typename Op, int N_READS>
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
IndexValPair<T> best{0, Op::init};
threadgroup IndexValPair<T> local_data[32];
IndexValPair<T> best(0, Op::init);
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
@@ -171,6 +184,7 @@ template <typename T, typename Op, int N_READS>
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \

View File

@@ -1,231 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};

View File

@@ -1,6 +1,155 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Add {
template <typename T> T operator()(T x, T y) { return x + y; }
};
struct Divide {
template <typename T> T operator()(T x, T y) { return x / y; }
};
struct Remainder {
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};
struct Equal {
template <typename T> bool operator()(T x, T y) { return x == y; }
};
struct NaNEqual {
template <typename T> bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real)
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T> bool operator()(T x, T y) { return x > y; }
};
struct GreaterEqual {
template <typename T> bool operator()(T x, T y) { return x >= y; }
};
struct Less {
template <typename T> bool operator()(T x, T y) { return x < y; }
};
struct LessEqual {
template <typename T> bool operator()(T x, T y) { return x <= y; }
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf) ? maxval :
(maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x >= y ? x : y;
}
};
struct Minimum {
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x <= y ? x : y;
}
};
struct Multiply {
template <typename T> T operator()(T x, T y) { return x * y; }
};
struct NotEqual {
template <typename T> bool operator()(T x, T y) { return x != y; }
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T> T operator()(T x, T y) { return x - y; }
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) { return x && y; };
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) { return x || y; };
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
#include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
@@ -240,4 +389,4 @@ instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)

View File

@@ -14,29 +14,10 @@ struct FloorDivide {
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <> complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
template <typename T> T operator()(T x, T y) { return x % y; }
template <> float operator()(float x, float y) { return fmod(x, y); }
template <> half operator()(half x, half y) { return fmod(x, y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
};
template <typename T, typename U, typename Op1, typename Op2>

View File

@@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"
typedef half float16_t;

View File

@@ -121,11 +121,5 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
if (real != 0 && (real < 0 != b.real < 0)) {
real += b.real;
}
if (imag != 0 && (imag < 0 != b.imag < 0)) {
imag += b.imag;
}
return {real, imag};
}

View File

@@ -0,0 +1,481 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DInputBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BM;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
int offsets_n[n_rows];
int offsets_oh[n_rows];
int offsets_ow[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bj),
params(params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params.oS[0] * params.oS[1];
for (int i = 0; i < n_rows; ++i) {
int offset_nhw = tid.y * BM + bi + i * bstride;
offsets_n[i] = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
offsets_oh[i] = hw / params.oS[1];
offsets_ow[i] = hw % params.oS[1];
}
(void)lid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
int n = offsets_n[i];
int oh = offsets_oh[i];
int ow = offsets_ow[i];
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
// Read from input if in bounds
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const device T* curr_src = src + n * params.in_strides[0] +
ih * params.in_strides[1] + iw * params.in_strides[2];
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
}
// Zero pad otherwise
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BN;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_.wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_h(0),
weight_w(0) {
(void)lid;
(void)tid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src =
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC Conv2DBlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t =
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
using loader_b_t =
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
using mma_t = Conv2DBlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const int K = params.wt_strides[0];
const int N = params.O;
B += c_col * K;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
}
};

View File

@@ -1,102 +1,16 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#define MLX_MTL_CONST static constant constexpr const
#include "mlx/backend/metal/kernels/conv.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Naive unfold with dilation
///////////////////////////////////////////////////////////////////////////////
template <typename T, int N>
[[kernel]] void naive_unfold_Nd(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
// Set out
out += gid.z * filter_size + gid.y * (params->C);
// Corrdinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)
int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int wS = gid.y;
bool valid = n < params->N;
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
is[i] = is_ / params->idil[i];
oS /= params->oS[i];
wS /= params->wS[i];
}
if(valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
out[gid.x] = in[in_offset + gid.x];
} else {
out[gid.x] = T(0);
}
}
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
[[kernel]] void naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) \
instantiate_naive_unfold_nd(name, itype, 2) \
instantiate_naive_unfold_nd(name, itype, 3)
instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
/// Slow and naive kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
@@ -144,8 +58,8 @@ template <typename T,
// Local in
for(int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
@@ -202,6 +116,59 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
in, wt, out,
params, tgp_memory,
tid, lid, simd_gid, simd_lid
);
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,19 @@
// Copyright © 2023 Apple Inc.
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int dil[NDIM]; // Kernel dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
};

View File

@@ -1,187 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T *src [[buffer(0)]],
device T *out [[buffer(1)]],
const constant int *src_shape [[buffer(2)]],
const constant size_t *src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int *slice_sizes [[buffer(5)]],
const constant int *axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T *src [[buffer(0)]], \
device T *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
const device src_t *src [[buffer(0)]], \
device src_t *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)

View File

@@ -121,18 +121,8 @@ struct GEMVKernel {
for(int tm = 0; tm < TM; tm++) {
// Load for the row
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * in_vec_size + col_idx];
}
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
// Accumulate results

View File

@@ -1,54 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const int ndim;
};
template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
if (is_unsigned_v<IdxT>) {
return idx;
} else {
return (idx < 0) ? idx + size : idx;
}
}
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
#define IDX_ARR_N(n) idx##n,
#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)

View File

@@ -0,0 +1,254 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_texture>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<device IdxT*, NIDX> buffers [[id(0)]];
device int* shapes [[id(NIDX + 1)]];
device size_t* strides [[id(NIDX + 2)]];
const int ndim [[id(NIDX + 3)]];
};
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
const device T *src [[buffer(0)]],
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
device T *out [[buffer(2)]],
const device int *src_shape [[buffer(3)]],
const device size_t *src_strides [[buffer(4)]],
const device size_t& src_ndim [[buffer(5)]],
const device int *slice_sizes [[buffer(6)]],
const device size_t& slice_size [[buffer(7)]],
const device int *axes [[buffer(8)]],
uint gid [[thread_position_in_grid]]) {
auto ind_idx = gid / slice_size;
auto ind_offset = gid % slice_size;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
out[gid] = src[src_idx + src_offset];
}
#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
const device src_type *src [[buffer(0)]], \
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const device int *src_shape [[buffer(3)]], \
const device size_t *src_strides [[buffer(4)]], \
const device size_t& src_ndim [[buffer(5)]], \
const device int *slice_sizes [[buffer(6)]], \
const device size_t& slice_size [[buffer(7)]], \
const device int* axes [[buffer(8)]], \
uint gid [[thread_position_in_grid]]);
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const device int *upd_shape [[buffer(3)]],
const device size_t *upd_strides [[buffer(4)]],
const device size_t& upd_ndim [[buffer(5)]],
const device size_t& upd_size [[buffer(6)]],
const device int *out_shape [[buffer(7)]],
const device size_t *out_strides [[buffer(8)]],
const device size_t& out_ndim [[buffer(9)]],
const device int* axes [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid / upd_size;
auto ind_offset = gid % upd_size;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
const device type *updates [[buffer(1)]], \
device mlx_atomic<type> *out [[buffer(2)]], \
const device int *upd_shape [[buffer(3)]], \
const device size_t *upd_strides [[buffer(4)]], \
const device size_t& upd_ndim [[buffer(5)]], \
const device size_t& upd_size [[buffer(6)]], \
const device int *out_shape [[buffer(7)]], \
const device size_t *out_strides [[buffer(8)]], \
const device size_t& out_ndim [[buffer(9)]], \
const device int* axes [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)

View File

@@ -15,14 +15,6 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T> struct AccT {
typedef T acc_t;
};
template <> struct AccT<bfloat16_t> {
typedef float acc_t;
};
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
@@ -39,23 +31,21 @@ template <typename T, const int BM, const int BN, const int group_size, const in
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
(void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_thread = 32 / bits;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / group_size;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[colgroup];
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[colgroup];
thread uint32_t w_local;
thread U result = 0;
thread U scale = 1;
thread U bias = 0;
thread U x_thread[el_per_thread];
thread T result = 0;
thread T scale = 1;
thread T bias = 0;
thread T x_thread[el_per_thread];
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
@@ -67,19 +57,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
if (out_row >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<el_per_thread; j++) {
x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
}
if (simd_gid < simdgroups_fetching_vec) {
x_block[lid] = x[lid + i];
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
@@ -107,7 +90,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= bits;
}
}
@@ -117,7 +100,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Store the result
if (simd_lid == 0) {
y[out_row] = static_cast<T>(result);
y[out_row] = result;
}
}
@@ -146,25 +129,23 @@ template <typename T, const int BM, const int BN, const int group_size, const in
constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size;
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[BM];
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[BM];
thread uint32_t w_local;
thread U result[el_per_int] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
thread T result[el_per_int] = {0};
thread T scale = 1;
thread T bias = 0;
thread T x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_g = out_vec_size / group_size;
int out_col_start = tid.y * (BN * el_per_int);
int out_col = out_col_start + simd_gid * el_per_int;
int out_col = (tid.y * BN + simd_gid) * el_per_int;
w += out_col / el_per_int;
scales += out_col_start / group_size;
biases += out_col_start / group_size;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
@@ -174,22 +155,26 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
int offset_lid = simd_lid + i;
int offset_gid = simd_gid + i;
bool thread_in_bounds = offset_lid < in_vec_size;
bool group_in_bounds = offset_gid < in_vec_size;
int offset = simd_lid + i;
bool thread_in_bounds = offset < in_vec_size;
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lid < groups_per_block && group_in_bounds) {
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -199,12 +184,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
w_local >>= bits;
}
}
@@ -219,7 +204,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
y[k] = static_cast<T>(result[k]);
y[k] = result[k];
}
}
}
@@ -258,6 +243,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
@@ -320,7 +306,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
if (y_row + offset_row < N) {
if (y_col + offset_col < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -435,9 +421,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
short num_k = min(BK, K - k);
if (num_els < BM || num_k < BK) {
loader_x.load_safe(short2(num_k, num_els));
if (num_els < BM) {
loader_x.load_safe(short2(BK, num_els));
} else {
loader_x.load_unsafe();
}
@@ -465,7 +450,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
if (num_k < BK) {
if (k + BK >= K) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
@@ -558,9 +543,6 @@ instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
instantiate_qmv_types( 32, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -588,9 +570,6 @@ instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
@@ -622,9 +601,6 @@ instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -653,6 +629,3 @@ instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)

View File

@@ -24,59 +24,11 @@ template <typename T, typename Op>
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
const device T *in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
@@ -88,18 +40,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// NB: this kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
Op op;
threadgroup U local_vals[simd_size];
U total_val = Op::init;
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
@@ -111,46 +98,6 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
@@ -164,80 +111,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
const device T *in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
@@ -255,9 +133,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize.x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
total_val = op.simd_reduce(total_val);
@@ -279,53 +194,6 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// Write row reduce output for threadgroup with 1st thread in thread group
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
@@ -343,59 +211,52 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
const device T *in,
threadgroup U *local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
inline void _contiguous_strided_reduce(
const device T *in,
device mlx_atomic<U> *out,
threadgroup U *local_data,
uint in_idx,
uint out_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
T local_vals[N_READS];
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
local_data[lsize.y * lid.x + lid.y] = total_val;
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if(lid.y == 0) {
// Perform reduction across columns in thread group
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
}
U val = op.init;
return val;
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
op.atomic_update(out, val, out_idx);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
@@ -404,13 +265,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
@@ -420,66 +281,18 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
ndim
);
Op op;
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
}
}
@@ -499,23 +312,6 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
@@ -526,15 +322,6 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
@@ -566,9 +353,6 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
@@ -578,9 +362,6 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
@@ -600,9 +381,6 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
@@ -612,8 +390,5 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)

View File

@@ -1,68 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
constant const size_t strides[3],
constant const int& offset,
constant const float& base,
constant const float& scale,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
// Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1 = x1 * costheta - x2 * sintheta;
float rx2 = x1 * sintheta + x2 * costheta;
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
}
#define instantiate_rope(name, type, traditional) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)

View File

@@ -1,265 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX> \
METAL_FUNC void scatter_1d_index_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant bool& upd_col_contiguous [[buffer(6)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(
idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
if (!upd_col_contiguous) {
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
} else {
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
}
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
out_shape, \
out_strides, \
upd_size, \
upd_col_contiguous, \
idx_buffers, \
gid); \
\
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int *upd_shape [[buffer(3)]],
const constant size_t *upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int *out_shape [[buffer(7)]],
const constant size_t *out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] \
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)

View File

@@ -1,11 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal;
using namespace mlx::steel;

View File

@@ -1,189 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>
>
>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
B += c_col * K;
C += c_row * N + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);

Some files were not shown because too many files have changed in this diff Show More