Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-11 14:34:37 +08:00
Compare commits
115 Commits
Commit SHAs:

9f9cb7a2ef, 7e26fd8032, eab2685c67, 50dfb664db, 0189ab6ab6, 9401507336, eb8321d863, 79ef49b2c2,
e110ca11e2, 226748b3e7, d568c7ee36, e6fecbb3e1, da83f899bb, 7e5674d8be, 0a558577bf, fb71a82ada,
23406c9e9e, b3ec792380, 6a9b584f3d, 81dd33af66, 8b76571896, e78a6518fa, 1873ffda01, c417e42116,
358e1fd6ab, 631dfbe673, 56a4eaed72, bf925d9dc7, 1a7ed5dcb6, 5be5daa6ef, 60cb11764e, cbd5445ea7,
2c7e9b5158, 2263e4b279, 863039da4c, 7178ac0111, e7f9710499, ff4223904d, a9f80d60f6, 2e158cf6d0,
8bd6bfa4b5, 8b1906abd0, 06375e6605, b21242faf1, cc05a281c4, fe96ceee66, 9814a2ae12, 6992498e7a,
21623156a3, 79c859e2e0, b00ac960b4, 02a9fc7bfa, f390957685, 17f57df797, 7f7b9662ea, 19bef39f5c,
a30e7ed2da, 8db7161c94, 09f1777896, 490c0c4fdc, c4a471c99d, 86f495985b, 67d1894759, 5bfe89bdb1,
82463e9938, 771575d27b, 20a01bbd9f, ec8578d41a, d0dbfe0b97, 3d405fb3b1, b0012cdd0f, 84d61d27aa,
ed83908931, ef5f7d1aea, 090ff659dc, 85c8a91a27, 581b699ac9, 8a0677d56d, b18468bf81, 107ba2891a,
cd9e184529, 2e7c02d5cd, ae18326533, 91eba8e485, d07e295c62, dce4bd74a4, ffff671273, 12d4507ee3,
8580d997ff, 061cf9a4ce, 99abb9eff4, fffe072028, a1a31eed27, ae812350f9, b63ef10a7f, 42afe27e12,
76e63212ff, aac2f9fb61, bddf23f175, 039da779d1, d88d2124b5, e142aaf8a1, 0caf35f4b8, 3fc993f82d,
741eb28443, 1a87dc5ea8, 2427fa171e, 639e06e1f3, 02fedbf1da, 110d9b149d, 9cbff5ec1d, 433c0206b0,
8915901966, f48bc496c7, 913b19329c
@@ -31,7 +31,7 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
@@ -49,11 +49,6 @@ jobs:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
# command: |
|
||||
# cd examples/extensions && python3 -m pip install .
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
@@ -69,18 +64,19 @@ jobs:
|
||||
default: "15.2.0"
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.large.gen1
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.8
|
||||
brew install openmpi
|
||||
python3.8 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@@ -101,11 +97,14 @@ jobs:
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
# command: |
|
||||
# cd examples/extensions && python3.11 -m pip install .
|
||||
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext -j8
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
@@ -117,7 +116,13 @@ jobs:
|
||||
name: Run CPP tests
|
||||
command: |
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
DEVICE=cpu ./build/tests/tests
|
||||
- run:
|
||||
name: Build small binary
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
|
||||
make -j
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
@@ -132,7 +137,7 @@ jobs:
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.large.gen1
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -143,7 +148,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
@@ -207,7 +212,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
|
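The macOS test job above now runs `python/tests/mpi_test_distributed.py` under `mpirun` with 8 ranks. As a hedged sketch only (not the contents of that test file), a minimal MPI-backed distributed check in MLX might look like the following; it assumes the `mlx.core.distributed` API (`init`, `all_sum`) used by the MPI backend:

```python
# distributed_sketch.py -- hypothetical example, launched with:
#   mpirun -np 8 python distributed_sketch.py
import mlx.core as mx

group = mx.distributed.init()      # assumed API: initialize the MPI-backed group
rank, size = group.rank(), group.size()

x = mx.ones((4, 4))
total = mx.distributed.all_sum(x)  # every rank receives the elementwise sum
mx.eval(total)

# With `size` ranks each contributing ones, every entry should equal `size`.
assert mx.allclose(total, size * x).item(), f"rank {rank}: unexpected all_sum result"
```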
@@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v17.0.6
|
||||
rev: v18.1.4
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.2.0
|
||||
rev: 24.4.2
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||
@@ -15,7 +15,9 @@ MLX was developed with contributions from the following individuals:
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`.
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
CMakeLists.txt (114)
@@ -15,12 +15,16 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.9.0)
|
||||
set(MLX_VERSION 0.14.0)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
@@ -82,9 +86,13 @@ elseif (MLX_BUILD_METAL)
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
set(MLX_METAL_VERSION METAL_3_1)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
set(MLX_METAL_VERSION METAL_3_0)
|
||||
else()
|
||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||
endif()
|
||||
@@ -92,6 +100,7 @@ elseif (MLX_BUILD_METAL)
|
||||
FetchContent_Declare(
|
||||
metal_cpp
|
||||
URL ${METAL_CPP_URL}
|
||||
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
@@ -101,55 +110,66 @@ elseif (MLX_BUILD_METAL)
|
||||
$<INSTALL_INTERFACE:include/metal_cpp>
|
||||
)
|
||||
target_link_libraries(
|
||||
mlx
|
||||
mlx PUBLIC
|
||||
${METAL_LIB}
|
||||
${FOUNDATION_LIB}
|
||||
${QUARTZ_LIB})
|
||||
|
||||
add_compile_definitions(${MLX_METAL_VERSION})
|
||||
endif()
|
||||
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
if (MLX_BUILD_CPU)
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||
endif()
|
||||
# Search and link with lapack.
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
/usr/local/opt/openblas/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old version
|
||||
# of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
if (NOT BLAS_FOUND)
|
||||
message(FATAL_ERROR "Must have BLAS installed")
|
||||
endif()
|
||||
# TODO find a cleaner way to do this
|
||||
find_path(BLAS_INCLUDE_DIRS cblas.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
$ENV{BLAS_HOME}/include)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||
endif()
|
||||
# Search and link with lapack.
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
/usr/local/opt/openblas/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old version
|
||||
# of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
if (NOT BLAS_FOUND)
|
||||
message(FATAL_ERROR "Must have BLAS installed")
|
||||
endif()
|
||||
# TODO find a cleaner way to do this
|
||||
find_path(BLAS_INCLUDE_DIRS cblas.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
$ENV{BLAS_HOME}/include)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
|
||||
find_package(MPI)
|
||||
if (MPI_FOUND)
|
||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
@@ -161,6 +181,14 @@ target_include_directories(
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
FetchContent_Declare(fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL
|
||||
)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
|
||||
|
||||
if (MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
|
@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
|
||||
|
||||
## Contributing
|
||||
|
||||
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
||||
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
||||
on contributing to MLX. See the
|
||||
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
||||
information on building from source, and running tests.
|
||||
|
||||
We are grateful for all of [our
|
||||
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||
to MLX and wish to be acknowledged, please add your name to the list in your
|
||||
pull request.
|
||||
|
||||
|
@@ -17,14 +17,13 @@
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " \
|
||||
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
|
||||
<< std::flush << std::setprecision(5) \
|
||||
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
|
||||
|
||||
template <typename F, typename... Args>
|
||||
double time_fn(F fn, Args... args) {
|
||||
double time_fn(F fn, Args&&... args) {
|
||||
// warmup
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
eval(fn(std::forward<Args>(args)...));
|
||||
|
benchmarks/python/conv1d_bench.py (new file, 123 lines)
@@ -0,0 +1,123 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
device_name = device_name.decode("utf-8").strip("\n")
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_1D(strides=1, padding=0, groups=1):
|
||||
def mx_conv_1D(a, b):
|
||||
ys = []
|
||||
for _ in range(N_iter_func):
|
||||
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_1D
|
||||
|
||||
|
||||
def make_pt_conv_1D(strides=1, padding=0, groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_1D(a, b):
|
||||
ys = []
|
||||
for _ in range(N_iter_func):
|
||||
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_1D
|
||||
|
||||
|
||||
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
|
||||
scale = 1.0 / math.sqrt(wH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_1D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_1D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv1d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 5, 32, 1, 2, 1),
|
||||
(4, 32, 32, 5, 32, 1, 2, 2),
|
||||
(4, 32, 32, 5, 32, 1, 2, 4),
|
||||
(4, 32, 32, 5, 32, 1, 2, 8),
|
||||
(4, 32, 32, 5, 32, 1, 2, 8),
|
||||
(4, 32, 32, 5, 32, 1, 2, 16),
|
||||
(4, 32, 32, 5, 32, 1, 2, 32),
|
||||
(4, 32, 256, 5, 512, 1, 2, 2),
|
||||
(4, 32, 256, 5, 512, 1, 2, 128),
|
||||
(4, 32, 256, 5, 512, 1, 2, 256),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
|
||||
for N, iH, C, wH, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, iH, C, wH, O, strides, padding, np_dtype, groups
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
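The new `conv1d_bench.py` above compares `mx.conv1d` against PyTorch on MPS, including grouped convolutions. As a quick orientation (a minimal sketch, not part of the benchmark), the channels-last shape convention it relies on is: input `(N, iH, C)`, weight `(O, wH, C // groups)`, output `(N, oH, O)`, with `groups` dividing both `C` and `O`:

```python
import mlx.core as mx

N, iH, C, wH, O, groups = 4, 32, 32, 5, 32, 4

a = mx.random.uniform(shape=(N, iH, C))            # input: (batch, length, channels)
b = mx.random.uniform(shape=(O, wH, C // groups))  # weight: (out_channels, kernel, in_channels // groups)

y = mx.conv1d(a, b, stride=1, padding=2, groups=groups)
mx.eval(y)
print(y.shape)  # (4, 32, 32): padding 2 with a width-5 kernel preserves the length
```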
@@ -28,11 +28,11 @@ def bench(f, a, b):
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
@@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
@@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding)
|
||||
f_pt = make_pt_conv_2D(strides, padding)
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
@@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
@@ -95,35 +97,40 @@ if __name__ == "__main__":
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
|
||||
for N, H, W, C, kH, kW, O, strides, padding in shapes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, np_dtype
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
57
benchmarks/python/fft_bench.py
Normal file
57
benchmarks/python/fft_bench.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def bandwidth_gb(runtime_ms, system_size):
|
||||
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
|
||||
bytes_per_gb = 1e9
|
||||
ms_per_s = 1e3
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||
|
||||
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
|
||||
return bandwidths
|
||||
|
||||
|
||||
def time_fft():
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_fft()
|
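The `bandwidth_gb` helper in the new FFT benchmark treats each complex64 element as read once and written once (2 × 8 bytes). A quick worked check of the formula with made-up numbers (not benchmark output):

```python
import numpy as np

def bandwidth_gb(runtime_ms, system_size):
    # Same formula as the benchmark: 16 bytes moved per complex64 element.
    bytes_per_fft = np.dtype(np.complex64).itemsize * 2  # 8 * 2 = 16
    return system_size * bytes_per_fft / runtime_ms * 1e3 / 1e9

# 2**22 elements processed in a hypothetical 1 ms is about 0.067 GB per run,
# i.e. roughly 67 GB/s of effective bandwidth.
print(round(bandwidth_gb(1.0, 2**22), 1))  # 67.1
```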
cmake/metal.14.0.diff (new file, 36 lines)
@@ -0,0 +1,36 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
|
||||
@@ -1906,6 +1906,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
cmake/metal.14.2.diff (new file, 36 lines)
@@ -0,0 +1,36 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
|
||||
@@ -1918,6 +1918,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
docs/Doxyfile (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
################################################################################
|
||||
# Primary project setup. #
|
||||
################################################################################
|
||||
|
||||
PROJECT_NAME = "MLX"
|
||||
OUTPUT_DIRECTORY = build
|
||||
XML_OUTPUT = xml
|
||||
HTML_OUTPUT = html
|
||||
STRIP_FROM_PATH = ../
|
||||
INPUT = ../mlx
|
||||
FILE_PATTERNS = *.h
|
||||
EXCLUDE_PATTERNS = */private/*
|
||||
CREATE_SUBDIRS = NO
|
||||
FULL_PATH_NAMES = YES
|
||||
RECURSIVE = YES
|
||||
GENERATE_HTML = YES
|
||||
GENERATE_LATEX = NO
|
||||
GENERATE_XML = YES
|
||||
XML_PROGRAMLISTING = YES
|
||||
|
||||
################################################################################
|
||||
# Doxygen preprocessor / parser control. #
|
||||
################################################################################
|
||||
|
||||
ENABLE_PREPROCESSING = YES
|
||||
MACRO_EXPANSION = YES
|
||||
EXPAND_ONLY_PREDEF = NO
|
||||
SKIP_FUNCTION_MACROS = NO
|
||||
|
||||
################################################################################
|
||||
# Compound extraction control. #
|
||||
################################################################################
|
||||
|
||||
EXTRACT_ALL = YES
|
||||
EXTRACT_PACKAGE = YES
|
||||
EXTRACT_STATIC = YES
|
||||
CASE_SENSE_NAMES = NO
|
||||
|
||||
################################################################################
|
||||
# Docstring control / customization. #
|
||||
################################################################################
|
||||
|
||||
JAVADOC_AUTOBRIEF = YES
|
||||
|
||||
################################################################################
|
||||
# Warning suppression. #
|
||||
################################################################################
|
||||
|
||||
QUIET = YES
|
||||
WARN_IF_UNDOCUMENTED = NO
|
@@ -2,12 +2,16 @@
|
||||
|
||||
### Setup (do once)
|
||||
|
||||
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
|
||||
for example with `conda`:
|
||||
Install Doxygen:
|
||||
|
||||
```
|
||||
conda install sphinx
|
||||
pip install sphinx-book-theme
|
||||
brew install doxygen
|
||||
```
|
||||
|
||||
Install Python packages:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Build
|
||||
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
|
||||
Build the docs from `mlx/docs/`
|
||||
|
||||
```
|
||||
make html
|
||||
doxygen && make html
|
||||
```
|
||||
|
||||
View the docs by running a server in `mlx/docs/build/html/`:
|
||||
|
docs/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
docs/src/_templates/nn-module-template.rst (new file, 20 lines)
@@ -0,0 +1,20 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
|
||||
{% block methods %}
|
||||
|
||||
{% if methods %}
|
||||
.. rubric:: {{ _('Methods') }}
|
||||
|
||||
.. autosummary::
|
||||
{% for item in methods %}
|
||||
{%- if item not in inherited_members and item != "__init__" %}
|
||||
~{{ name }}.{{ item }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
@@ -22,6 +22,7 @@ extensions = [
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.napoleon",
|
||||
"breathe",
|
||||
]
|
||||
|
||||
python_use_unqualified_type_names = True
|
||||
@@ -33,6 +34,9 @@ intersphinx_mapping = {
|
||||
"numpy": ("https://numpy.org/doc/stable/", None),
|
||||
}
|
||||
|
||||
breathe_projects = {"mlx": "../build/xml"}
|
||||
breathe_default_project = "mlx"
|
||||
|
||||
templates_path = ["_templates"]
|
||||
html_static_path = ["_static"]
|
||||
source_suffix = ".rst"
|
||||
|
@@ -3,4 +3,5 @@
|
||||
Operations
|
||||
==========
|
||||
|
||||
|
||||
.. doxygengroup:: ops
|
||||
:content-only:
|
||||
|
@@ -1,24 +1,16 @@
|
||||
Developer Documentation
|
||||
=======================
|
||||
Custom Extensions in MLX
|
||||
========================
|
||||
|
||||
MLX provides a open and flexible backend to which users may add operations
|
||||
and specialized implementations without much hassle. While the library supplies
|
||||
efficient operations that can be used and composed for any number of
|
||||
applications, there may arise cases where new functionalities or highly
|
||||
optimized implementations are needed. For such cases, you may design and
|
||||
implement your own operations that link to and build on top of :mod:`mlx.core`.
|
||||
We will introduce the inner-workings of MLX and go over a simple example to
|
||||
learn the steps involved in adding new operations to MLX with your own CPU
|
||||
and GPU implementations.
|
||||
You can extend MLX with custom operations on the CPU or GPU. This guide
|
||||
explains how to do that with a simple example.
|
||||
|
||||
Introducing the Example
|
||||
-----------------------
|
||||
|
||||
Let's say that you would like an operation that takes in two arrays,
|
||||
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
|
||||
respectively, and then adds them together to get the result
|
||||
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||
writing out a function as follows:
|
||||
Let's say you would like an operation that takes in two arrays, ``x`` and
|
||||
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
|
||||
and then adds them together to get the result ``z = alpha * x + beta * y``.
|
||||
You can do that in MLX directly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -27,44 +19,35 @@ writing out a function as follows:
|
||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||
return alpha * x + beta * y
|
||||
|
||||
This function performs that operation while leaving the implementations and
|
||||
differentiation to MLX.
|
||||
This function performs that operation while leaving the implementation and
|
||||
function transformations to MLX.
|
||||
|
||||
However, you work with vector math libraries often and realize that the
|
||||
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
|
||||
You would really like the part of your applications that does this operation
|
||||
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
|
||||
our assumptions on to you, let's also assume that you want to learn how to add
|
||||
your own implementation for the gradients of your new operation while going
|
||||
over the ins-and-outs of the MLX framework.
|
||||
However you may need to customize the underlying implementation, perhaps to
|
||||
make it faster or for custom differentiation. In this tutorial we will go
|
||||
through adding custom extensions. It will cover:
|
||||
|
||||
Well, what a coincidence! You are in the right place. Over the course of this
|
||||
example, we will learn:
|
||||
|
||||
* The structure of the MLX library from the frontend API to the backend implementations.
|
||||
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
|
||||
* How to implement your own GPU implementation using metal.
|
||||
* How to add your own ``vjp`` and ``jvp``.
|
||||
* How to build your implementations, link them to MLX, and bind them to python.
|
||||
* The structure of the MLX library.
|
||||
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
|
||||
* Implementing a GPU operation using metal.
|
||||
* Adding the ``vjp`` and ``jvp`` function transformation.
|
||||
* Building a custom extension and binding it to python.
|
||||
|
||||
Operations and Primitives
|
||||
-------------------------
|
||||
|
||||
In one sentence, operations in MLX build the computation graph, and primitives
|
||||
provide the rules for evaluation and transformations of said graph. Let's start
|
||||
by discussing operations in more detail.
|
||||
Operations in MLX build the computation graph. Primitives provide the rules for
|
||||
evaluating and transforming the graph. Let's start by discussing operations in
|
||||
more detail.
|
||||
|
||||
Operations
|
||||
^^^^^^^^^^^
|
||||
|
||||
Operations are the frontend functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
|
||||
operations in the Python API (:ref:`ops`).
|
||||
Operations are the front-end functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
||||
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
|
||||
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
|
||||
C++ API:
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
|
||||
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
||||
C++:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -83,10 +66,7 @@ C++ API:
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
|
||||
This operation itself can call other operations within it if needed. So, the
|
||||
simplest way to go about implementing this operation would be do so in terms
|
||||
of existing operations.
|
||||
The simplest way to this operation is in terms of existing operations:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -100,25 +80,23 @@ of existing operations.
|
||||
// Scale x and y on the provided stream
|
||||
auto ax = multiply(array(alpha), x, s);
|
||||
auto by = multiply(array(beta), y, s);
|
||||
|
||||
|
||||
// Add and return
|
||||
return add(ax, by, s);
|
||||
}
|
||||
|
||||
However, as we discussed earlier, this is not our goal. The operations themselves
|
||||
do not contain the implementations that act on the data, nor do they contain the
|
||||
rules of transformations. Rather, they are an easy to use interface that build
|
||||
on top of the building blocks we call :class:`Primitive`.
|
||||
The operations themselves do not contain the implementations that act on the
|
||||
data, nor do they contain the rules of transformations. Rather, they are an
|
||||
easy to use interface that use :class:`Primitive` building blocks.
|
||||
|
||||
Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create an output given a set of input :class:`array` . Further,
|
||||
a :class:`Primitive` is a class that contains rules on how it is evaluated
|
||||
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
|
||||
``jvp``. These words on their own can be a bit abstract, so lets take a step
|
||||
back and go to our example to give ourselves a more concrete image.
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array jvp(
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<array, int> vmap(
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class and
|
||||
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
|
||||
``beta`` as parameters. It then provides implementations of how the array ``out``
|
||||
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
|
||||
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
|
||||
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
||||
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
|
||||
implementations of how the output array is produced given the inputs through
|
||||
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
|
||||
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
|
||||
:meth:`Axpby::vmap`.
|
||||
|
||||
Using the Primitives
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
Using the Primitive
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to
|
||||
the computation graph. An :class:`array` can be constructed by providing its
|
||||
data type, shape, the :class:`Primitive` that computes it, and the
|
||||
:class:`array` inputs that are passed to the primitive.
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to the
|
||||
computation graph. An :class:`array` can be constructed by providing its data
|
||||
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
|
||||
inputs that are passed to the primitive.
|
||||
|
||||
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -238,27 +221,26 @@ This operation now handles the following:
|
||||
Implementing the Primitive
|
||||
--------------------------
|
||||
|
||||
No computation happens when we call the operation alone. In effect, the
|
||||
operation only builds the computation graph. When we evaluate the output
|
||||
array, MLX schedules the execution of the computation graph, and calls
|
||||
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
|
||||
stream/device specified by the user.
|
||||
No computation happens when we call the operation alone. The operation only
|
||||
builds the computation graph. When we evaluate the output array, MLX schedules
|
||||
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
|
||||
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
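A minimal Python sketch of the lazy-evaluation behavior described above, using built-in MLX ops rather than the `axpby` extension:

```python
import mlx.core as mx

x = mx.ones((3, 3))
y = mx.ones((3, 3))

with mx.stream(mx.cpu):
    # Calling ops only records nodes in the computation graph; nothing runs yet.
    z = 2.0 * x + 3.0 * y
    # Evaluating the output triggers execution on the chosen stream, which is
    # when a primitive's eval_cpu (here) or eval_gpu would be called.
    mx.eval(z)

print(z[0, 0].item())  # 5.0
```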
|
||||
|
||||
.. warning::
|
||||
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
|
||||
no memory has been allocated for the output array. It falls on the implementation
|
||||
of these functions to allocate memory as needed
|
||||
of these functions to allocate memory as needed.
|
||||
|
||||
Implementing the CPU Backend
|
||||
Implementing the CPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's start by trying to implement a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
Let's start by implementing a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
}
|
||||
}
|
||||
|
||||
Now, we would like our implementation to be able to do this pointwise operation
|
||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
|
||||
if we encounter an unexpected type.
|
||||
Our implementation should work for all incoming floating point arrays.
|
||||
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
``complex64``. We throw an error if we encounter an unexpected type.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
"[Axpby] Only supports floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
We have a fallback implementation! Now, to do what we are really here to do.
|
||||
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
|
||||
framework? Well, there are 3 complications to keep in mind:
|
||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
||||
provided by the Accelerate_ framework for a faster implementation in certain
|
||||
cases:
|
||||
|
||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||
floats. We can only direct to it for ``float32`` types
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
|
||||
have fixed strides between them. Possibly due to broadcasts and transposes,
|
||||
we aren't guaranteed that the inputs fit this requirement. We can
|
||||
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
|
||||
column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
|
||||
MLX expects to write out the answer to a new array. We must copy the elements
|
||||
of ``y`` into the output array and use that as an input to ``axpby``
|
||||
floats. We can only use it for ``float32`` types.
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
||||
elements have fixed strides between them. We only direct to Accelerate
|
||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
||||
MLX expects to write the output to a new array. We must copy the elements
|
||||
of ``y`` into the output and use that as an input to ``axpby``.
|
||||
|
||||
Let's write out an implementation that uses Accelerate in the right conditions.
|
||||
It must simply allocate data for the output, copy elements of ``y`` into it,
|
||||
and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
Let's write an implementation that uses Accelerate in the right conditions.
|
||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
||||
:func:`catlas_saxpby` from accelerate.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
|
||||
// This specialization requires both x and y be contiguous in the same mode
|
||||
// i.e: corresponding linear indices in both point to corresponding elements
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||
Luckily, we can always just direct back to :meth:`Axpby::eval`.
|
||||
|
||||
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||
For inputs that do not fit the criteria for accelerate, we fall back to
|
||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
||||
:meth:`Axpby::eval_cpu`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, out);
|
||||
// Fall back to common back-end if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
}
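The contiguity check itself falls outside the snippet above. Based on the
criteria listed earlier, it amounts to something like the following sketch
(the flag names mirror MLX's array flags and are assumptions here):

.. code-block:: C++

   // Sketch: only take the Accelerate path for float32 inputs that are
   // contiguous in the same mode (both row contiguous or both column contiguous).
   bool can_use_accelerate = out.dtype() == float32 &&
       ((x.flags().row_contiguous && y.flags().row_contiguous) ||
        (x.flags().col_contiguous && y.flags().col_contiguous));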
|
||||
|
||||
We have now hit a milestone! Just this much is enough to run the operation
|
||||
:meth:`axpby` on a CPU stream!
|
||||
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||
you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
If you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
Implementing the GPU Backend
|
||||
Implementing the GPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
all GPU kernels in MLX are written using metal.
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
GPU kernels in MLX are written using Metal.
|
||||
|
||||
.. note::
|
||||
|
||||
Here are some helpful resources if you are new to metal!
|
||||
Here are some helpful resources if you are new to Metal:
|
||||
|
||||
* A walkthrough of the metal compute pipeline: `Metal Example`_
|
||||
* Documentation for metal shading language: `Metal Specification`_
|
||||
* Using metal from C++: `Metal-cpp`_
|
||||
|
||||
Let's keep the GPU algorithm simple. We will launch exactly as many threads
|
||||
as there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
|
||||
element in the output.
|
||||
Let's keep the GPU kernel simple. We will launch exactly as many threads as
|
||||
there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the point-wise operation, and update its assigned
|
||||
element in the output.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -457,15 +427,14 @@ element in the output.
|
||||
// Convert linear indices to offsets in array
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
|
||||
|
||||
// Do the operation and update the output
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
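For intuition, ``elem_to_loc`` unravels the linear thread index into a strided
offset. A host-side sketch of the same arithmetic (an illustration, not part of
the extension code) looks like this:

.. code-block:: C++

   // Unravel a linear index into per-dimension coordinates (innermost
   // dimension first) and accumulate the strided offset.
   size_t elem_to_loc_sketch(
       size_t index,
       const int* shape,
       const size_t* strides,
       int ndim) {
     size_t loc = 0;
     for (int i = ndim - 1; i >= 0; --i) {
       loc += (index % shape[i]) * strides[i];
       index /= shape[i];
     }
     return loc;
   }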
|
||||
|
||||
We then need to instantiate this template for all floating point types and give
|
||||
each instantiation a unique host name so we can identify the right kernel for
|
||||
each data type.
|
||||
each instantiation a unique host name so we can identify it.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -488,29 +457,21 @@ each data type.
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
||||
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||
will see later in :ref:`Building with CMake`. In the following example, we
|
||||
assume that the library ``mlx_ext.metallib`` will always be co-located with
|
||||
the executable or shared library calling the :meth:`register_library` function.
|
||||
The :meth:`register_library` function takes the library's name and potential
|
||||
path (or in this case, a function that can produce the path of the metal
|
||||
library) and tries to load that library if it hasn't already been registered
|
||||
by the relevant static :class:`mlx::core::metal::Device` object. This is why
|
||||
it is important to package your C++ library with the metal library. We will
|
||||
go over this process in more detail later.
|
||||
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
below.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
@@ -518,10 +479,10 @@ below.
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
@@ -533,7 +494,7 @@ below.
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
@@ -542,17 +503,17 @@ below.
|
||||
size_t nelem = out.size();
|
||||
|
||||
// Encode input arrays to kernel
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, y, 1);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(y, 1);
|
||||
|
||||
// Encode output arrays to kernel
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
@@ -570,33 +531,30 @@ below.
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
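The grid and threadgroup computation is not shown in the snippet above. The
usual pattern, sketched below with assumed metal-cpp names, launches one thread
per output element, capped by the kernel's maximum threadgroup size:

.. code-block:: C++

   // Sketch: one thread per output element.
   size_t nthreads = out.size();
   NS::UInteger group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (group_size > nthreads) {
     group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(group_size, 1, 1);
   MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);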
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
|
||||
A few things to note about MLX and metal before moving on. MLX keeps track
|
||||
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
|
||||
to give us the active metal compute command encoder instead of building a
|
||||
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and committing the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
:class:`metal::Device` if you would like to study this routine further.
|
||||
A few things to note about MLX and Metal before moving on. MLX keeps track of
|
||||
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
|
||||
associated. We rely on :meth:`d.get_command_encoder` to give us the active
|
||||
metal compute command encoder instead of building a new one and calling
|
||||
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
|
||||
pipelines) to the active command buffer until some specified limit is hit or
|
||||
the command buffer needs to be flushed for synchronization.
|
||||
|
||||
Primitive Transforms
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Now that we have come this far, let's also learn how to add implementations to
|
||||
transformations in a :class:`Primitive`. These transformations can be built on
|
||||
top of our operations, including the one we just defined now. Which then gives
|
||||
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
Next, let's add implementations for transformations in a :class:`Primitive`.
|
||||
These transformations can be built on top of other operations, including the
|
||||
one we just defined:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array Axpby::jvp(
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return multiply(scale_arr, tangents[0], stream());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<int>& /* unused */) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotan.dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
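With the vjp in place, here is a quick sketch (in C++, mirroring the style of
the MLX tutorial; not part of the original example) of exercising it through
:meth:`grad`:

.. code-block:: C++

   // d/dx sum(alpha * x + beta * y) = alpha, so the gradient is 4 everywhere.
   auto x = ones({3});
   auto y = ones({3});
   auto fn = [&y](array in) { return sum(axpby(in, y, 4.0f, 2.0f)); };
   auto dfdx = grad(fn)(x);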
|
||||
|
||||
Finally, you need not have a transformation fully defined to start using your
|
||||
own :class:`Primitive`.
|
||||
Note that a transformation does not need to be fully defined to start using
|
||||
the :class:`Primitive`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
throw std::runtime_error("[Axpby] vmap not implemented.");
|
||||
}
|
||||
|
||||
Building and Binding
|
||||
--------------------
|
||||
|
||||
Let's look at the overall directory structure first.
|
||||
Let's look at the overall directory structure first.
|
||||
|
||||
| extensions
|
||||
| ├── axpby
|
||||
@@ -666,40 +625,39 @@ Let's look at the overall directory structure first.
|
||||
| └── setup.py
|
||||
|
||||
* ``extensions/axpby/`` defines the C++ extension library
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated python package
|
||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
python bindings
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated Python package
|
||||
* ``extensions/bindings.cpp`` provides Python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
Python bindings
|
||||
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
|
||||
the python package
|
||||
the Python package
|
||||
|
||||
Binding to Python
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
|
||||
We use nanobind_ to build a Python API for the C++ library. Since bindings for
|
||||
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
|
||||
already provided, adding our :meth:`axpby` is simple!
|
||||
already provided, adding our :meth:`axpby` is simple.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
)");
|
||||
}
|
||||
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
whistles such as the literal names and doc-strings.
|
||||
|
||||
.. warning::
|
||||
|
||||
:mod:`mlx.core` needs to be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:mod:`mlx.core` must be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:class:`mlx.core.array` are available.
|
||||
|
||||
.. _Building with CMake:
|
||||
@@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings.
|
||||
Building with CMake
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Building the C++ extension library itself is simple, it only requires that you
|
||||
``find_package(MLX CONFIG)`` and then link it to your library.
|
||||
Building the C++ extension library only requires that you ``find_package(MLX
|
||||
CONFIG)`` and then link it to your library.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you
|
||||
# Link to mlx
|
||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
We also need to build the attached metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with MLX package).
|
||||
We also need to build the attached Metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with the MLX package).
|
||||
|
||||
Here is what that looks like in practice!
|
||||
Here is what that looks like in practice:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
@@ -779,27 +737,29 @@ Here is what that looks like in practice!
|
||||
|
||||
endif()
|
||||
|
||||
Finally, we build the Pybind11_ bindings
|
||||
Finally, we build the nanobind_ bindings
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
||||
Building with ``setuptools``
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Once we have set out the CMake build rules as described above, we can use the
|
||||
build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
build utilities defined in :mod:`mlx.extension`:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
from mlx import extension
|
||||
from setuptools import setup
|
||||
@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages = ["mlx_sample_extensions"],
|
||||
package_dir = {"": "mlx_sample_extensions"},
|
||||
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev":[]},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
||||
.. note::
|
||||
We treat ``extensions/mlx_sample_extensions`` as the package directory
|
||||
even though it only contains a ``__init__.py`` to ensure the following:
|
||||
|
||||
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
You can build inplace for development using
|
||||
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
To build the package, first install the build dependencies with ``pip install
|
||||
-r requirements.txt``. You can then build inplace for development using
|
||||
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).
|
||||
|
||||
This will result in a directory structure as follows:
|
||||
This results in the directory structure:
|
||||
|
||||
| extensions
|
||||
| ├── mlx_sample_extensions
|
||||
| │ ├── __init__.py
|
||||
| │ ├── libmlx_ext.dylib # C++ extension library
|
||||
| │ ├── mlx_ext.metallib # Metal library
|
||||
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
|
||||
| │ └── _ext.cpython-3x-darwin.so # Python Binding
|
||||
| ...
|
||||
|
||||
When you try to install using the command ``python -m pip install .``
|
||||
(in ``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as ``package_data``.
|
||||
When you try to install using the command ``python -m pip install .`` (in
|
||||
``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
|
||||
copied along with the Python binding since they are specified as
|
||||
``package_data``.
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the python package and play with it as you would any other MLX operation!
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the Python package and play with it as you would any other MLX operation.
|
||||
|
||||
Let's looks at a simple script and it's results!
|
||||
Let's look at a simple script and its results:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -863,7 +825,7 @@ Let's looks at a simple script and it's results!
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correctness: {mx.all(c == 6.0).item()}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
@@ -874,12 +836,12 @@ Output:
|
||||
c correct: True
|
||||
|
||||
Results
|
||||
^^^^^^^^^^^^^^^^
|
||||
^^^^^^^
|
||||
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we first defined on the CPU.
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
alpha = 4.0
|
||||
beta = 2.0
|
||||
|
||||
mx.eval((x, y))
|
||||
mx.eval(x, y)
|
||||
|
||||
def bench(f):
|
||||
# Warm up
|
||||
@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
|
||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
||||
|
||||
Results:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Simple axpby: 0.114 s | Custom axpby: 0.109 s
|
||||
|
||||
We see some modest improvements right away!
|
||||
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
|
||||
modest improvements right away!
|
||||
|
||||
This operation is now good to be used to build other operations, in
|
||||
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
|
||||
:meth:`grad`!
|
||||
:meth:`grad`.
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx <code>`_.
|
||||
|
||||
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
|
||||
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
|
||||
|
||||
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
|
||||
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
|
||||
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
|
||||
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
|
||||
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
||||
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
|
||||
|
@@ -1,29 +1,45 @@
|
||||
Metal Debugger
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Profiling is a key step for performance optimization. You can build MLX with
|
||||
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization
|
||||
workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
|
||||
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||
|
||||
* Records source during Metal compilation, for later inspection while
|
||||
debugging.
|
||||
* Labels Metal objects such as command queues, improving capture readability.
|
||||
|
||||
The ``metal::start_capture`` function initiates a capture of all MLX GPU work.
|
||||
To build with debugging enabled in Python prepend
|
||||
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
|
||||
|
||||
.. code-block:: C++
|
||||
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
|
||||
work.
|
||||
|
||||
int main() {
|
||||
metal::start_capture("/Users/Jane/Developer/MLX.gputrace");
|
||||
.. note::
|
||||
|
||||
auto a = arange(10.f, 20.f, 1.f, float32);
|
||||
auto b = arange(30.f, 40.f, 1.f, float32);
|
||||
auto c = add(a, b);
|
||||
To capture a GPU trace you must run the application with
|
||||
``MTL_CAPTURE_ENABLED=1``.
|
||||
|
||||
eval(c);
|
||||
.. code-block:: python
|
||||
|
||||
metal::stop_capture();
|
||||
}
|
||||
import mlx.core as mx
|
||||
|
||||
a = mx.random.uniform(shape=(512, 512))
|
||||
b = mx.random.uniform(shape=(512, 512))
|
||||
mx.eval(a, b)
|
||||
|
||||
trace_file = "mlx_trace.gputrace"
|
||||
|
||||
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
|
||||
# that the path trace_file does not already exist.
|
||||
mx.metal.start_capture(trace_file)
|
||||
|
||||
for _ in range(10):
|
||||
mx.eval(mx.add(a, b))
|
||||
|
||||
mx.metal.stop_capture()
|
||||
|
||||
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
|
||||
has a great overview of all operations. Check out the `Metal debugger
|
||||
@@ -35,8 +51,8 @@ documentation`_ for more information.
|
||||
Xcode Workflow
|
||||
--------------
|
||||
|
||||
You can skip saving to a path by running within Xcode. First, generate an Xcode
|
||||
project using CMake.
|
||||
You can skip saving to a path by running within Xcode. First, generate an
|
||||
Xcode project using CMake.
|
||||
|
||||
.. code-block::
|
||||
|
||||
|
@@ -74,7 +74,7 @@ Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install git+https://github.com/wjakob/nanobind.git
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
|
||||
Then simply build and install MLX using pip:
|
||||
|
||||
@@ -120,7 +120,7 @@ Create a build directory and run CMake and make:
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. && make -j
|
||||
cmake .. && make -j
|
||||
|
||||
Run tests with:
|
||||
|
||||
@@ -139,7 +139,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
|
||||
preprocessor constant ``METAL_PATH`` should be defined at build time and it
|
||||
should point to the path to the built metal library.
|
||||
|
||||
.. list-table:: Build Options
|
||||
.. list-table:: Build Options
|
||||
:widths: 25 8
|
||||
:header-rows: 1
|
||||
|
||||
@@ -153,33 +153,67 @@ should point to the path to the built metal library.
|
||||
- OFF
|
||||
* - MLX_BUILD_METAL
|
||||
- ON
|
||||
* - MLX_BUILD_CPU
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
||||
* - MLX_METAL_DEBUG
|
||||
- OFF
|
||||
|
||||
* - MLX_BUILD_SAFETENSORS
|
||||
- ON
|
||||
* - MLX_BUILD_GGUF
|
||||
- ON
|
||||
* - MLX_METAL_JIT
|
||||
- OFF
|
||||
|
||||
.. note::
|
||||
|
||||
If you have multiple Xcode installations and wish to use
|
||||
a specific one while building, you can do so by adding the
|
||||
following environment variable before building
|
||||
If you have multiple Xcode installations and wish to use
|
||||
a specific one while building, you can do so by adding the
|
||||
following environment variable before building
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
|
||||
|
||||
Further, you can use the following command to find out which
|
||||
Further, you can use the following command to find out which
|
||||
macOS SDK will be used
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
To produce a smaller binary, use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
|
||||
and ``BUILD_SHARED_LIBS=ON``.
|
||||
|
||||
The MLX CMake build has several additional options to make smaller binaries.
|
||||
For example, if you don't need the CPU backend or support for safetensors and
|
||||
GGUF, you can do:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
|
||||
The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library, which
|
||||
contains pre-built GPU kernels. This substantially reduces the size of the
|
||||
Metal library by run-time compiling kernels the first time they are used in MLX
|
||||
on a given machine. Note that run-time compilation incurs a cold-start cost which can
|
||||
be anywhere from a few hundred milliseconds to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Metal not found
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -201,7 +235,7 @@ Then set the active developer directory:
|
||||
|
||||
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
|
||||
|
||||
x86 Shell
|
||||
x86 Shell
|
||||
~~~~~~~~~
|
||||
|
||||
.. _build shell:
|
||||
|
@@ -16,3 +16,4 @@ Devices and Streams
|
||||
new_stream
|
||||
set_default_stream
|
||||
stream
|
||||
synchronize
|
||||
|
@@ -8,5 +8,8 @@ Linear Algebra
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
inv
|
||||
norm
|
||||
cholesky
|
||||
qr
|
||||
svd
|
||||
|
@@ -3,12 +3,17 @@ Metal
|
||||
|
||||
.. currentmodule:: mlx.core.metal
|
||||
|
||||
.. autosummary::
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
device_info
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -173,6 +173,7 @@ In detail:
|
||||
:toctree: _autosummary
|
||||
|
||||
value_and_grad
|
||||
quantize
|
||||
|
||||
.. toctree::
|
||||
|
||||
|
@@ -15,6 +15,7 @@ Layers
|
||||
BatchNorm
|
||||
Conv1d
|
||||
Conv2d
|
||||
Conv3d
|
||||
Dropout
|
||||
Dropout2d
|
||||
Dropout3d
|
||||
@@ -31,6 +32,7 @@ Layers
|
||||
Mish
|
||||
MultiHeadAttention
|
||||
PReLU
|
||||
QuantizedEmbedding
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
@@ -43,4 +45,4 @@ Layers
|
||||
Softshrink
|
||||
Step
|
||||
Transformer
|
||||
Upsample
|
||||
Upsample
|
||||
|
@@ -5,13 +5,14 @@ Operations
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
abs
|
||||
add
|
||||
addmm
|
||||
all
|
||||
allclose
|
||||
allclose
|
||||
any
|
||||
arange
|
||||
arccos
|
||||
@@ -19,19 +20,27 @@ Operations
|
||||
arcsin
|
||||
arcsinh
|
||||
arctan
|
||||
arctan2
|
||||
arctanh
|
||||
argmax
|
||||
argmin
|
||||
argpartition
|
||||
argsort
|
||||
array_equal
|
||||
as_strided
|
||||
atleast_1d
|
||||
atleast_2d
|
||||
atleast_3d
|
||||
bitwise_and
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
conv1d
|
||||
conv2d
|
||||
@@ -42,6 +51,7 @@ Operations
|
||||
cummin
|
||||
cumprod
|
||||
cumsum
|
||||
degrees
|
||||
dequantize
|
||||
diag
|
||||
diagonal
|
||||
@@ -51,12 +61,15 @@ Operations
|
||||
erf
|
||||
erfinv
|
||||
exp
|
||||
expm1
|
||||
expand_dims
|
||||
eye
|
||||
flatten
|
||||
floor
|
||||
floor_divide
|
||||
full
|
||||
gather_mm
|
||||
gather_qmm
|
||||
greater
|
||||
greater_equal
|
||||
identity
|
||||
@@ -66,6 +79,8 @@ Operations
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
issubdtype
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
@@ -83,22 +98,28 @@ Operations
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
moveaxis
|
||||
multiply
|
||||
negative
|
||||
not_equal
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
partition
|
||||
pad
|
||||
power
|
||||
prod
|
||||
quantize
|
||||
quantized_matmul
|
||||
radians
|
||||
reciprocal
|
||||
remainder
|
||||
repeat
|
||||
reshape
|
||||
right_shift
|
||||
round
|
||||
rsqrt
|
||||
save
|
||||
@@ -117,6 +138,7 @@ Operations
|
||||
square
|
||||
squeeze
|
||||
stack
|
||||
std
|
||||
stop_gradient
|
||||
subtract
|
||||
sum
|
||||
@@ -128,6 +150,7 @@ Operations
|
||||
tensordot
|
||||
tile
|
||||
topk
|
||||
trace
|
||||
transpose
|
||||
tri
|
||||
tril
|
||||
|
@@ -1,5 +1,7 @@
|
||||
.. _optimizers:
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
Optimizers
|
||||
==========
|
||||
|
||||
@@ -34,3 +36,8 @@ model's parameters and the **optimizer state**.
|
||||
optimizers/optimizer
|
||||
optimizers/common_optimizers
|
||||
optimizers/schedulers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
clip_grad_norm
|
||||
|
@@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||
gumbel
|
||||
key
|
||||
normal
|
||||
multivariate_normal
|
||||
randint
|
||||
seed
|
||||
split
|
||||
|
@@ -19,3 +19,5 @@ return python trees will be using the default python ``dict``, ``list`` and
|
||||
tree_flatten
|
||||
tree_unflatten
|
||||
tree_map
|
||||
tree_map_with_path
|
||||
tree_reduce
|
||||
|
@@ -18,7 +18,7 @@ describe below.
|
||||
Transforming Compute Graphs
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Lazy evaluation let's us record a compute graph without actually doing any
|
||||
Lazy evaluation lets us record a compute graph without actually doing any
|
||||
computations. This is useful for function transformations like :func:`grad` and
|
||||
:func:`vmap` and graph optimizations.
|
||||
|
||||
|
@@ -49,7 +49,7 @@ it will be added. You can load the array with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> mx.load("array.npy", a)
|
||||
>>> mx.load("array.npy")
|
||||
array([1], dtype=float32)
|
||||
|
||||
Here's an example of saving several arrays to a single file:
|
||||
|
@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
|
||||
build_example(linear_regression.cpp)
|
||||
build_example(logistic_regression.cpp)
|
||||
build_example(metal_capture.cpp)
|
||||
build_example(distributed.cpp)
|
||||
|
22
examples/cpp/distributed.cpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
if (!distributed::is_available()) {
|
||||
std::cout << "No communication backend found" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto global_group = distributed::init();
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
array x = ones({10});
|
||||
array out = distributed::all_reduce_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
@@ -8,9 +8,10 @@
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
// Enable the MLX_METAL_DEBUG CMake option to enhance the capture with groups,
|
||||
// labels, etc.
|
||||
assert(metal::start_capture());
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
metal::start_capture("mlx_trace.gputrace");
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
|
@@ -89,8 +89,8 @@ void automatic_differentiation() {
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto df2dx2 = grad(grad(fn))(x);
|
||||
// df2dx2 is 2
|
||||
auto d2fdx2 = grad(grad(fn))(x);
|
||||
// d2fdx2 is 2
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
@@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(mlx_sample_extensions LANGUAGES CXX)
|
||||
project(_ext LANGUAGES CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(Python COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
||||
@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
# Build metallib
|
||||
if(MLX_BUILD_METAL)
|
||||
|
||||
mlx_build_metallib(
|
||||
TARGET mlx_ext_metallib
|
||||
TITLE mlx_ext
|
||||
@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)
|
||||
|
||||
endif()
|
||||
|
||||
# ----------------------------- Pybind -----------------------------
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
# ----------------------------- Python Bindings -----------------------------
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
24
examples/extensions/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
## Build
|
||||
|
||||
```
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For faster builds during development, you can also pre-install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And then run:
|
||||
|
||||
```
|
||||
python setup.py build_ext -j8 --inplace
|
||||
```
|
||||
|
||||
## Test
|
||||
|
||||
```
|
||||
python test.py
|
||||
```
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
@@ -43,7 +43,7 @@ array axpby(
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = is_floating_point(promoted_dtype)
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
|
||||
@@ -106,12 +106,12 @@ void axpby_impl(
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out_arr) {
|
||||
auto out = out_arr[0];
|
||||
std::vector<array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
auto out = outarr[0];
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, outarr);
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#else // Accelerate not available
|
||||
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
eval(inputs, out);
|
||||
const std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
auto out = outarr[0];
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
@@ -261,7 +257,7 @@ void Axpby::eval_gpu(
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
@@ -270,11 +266,11 @@ void Axpby::eval_gpu(
|
||||
size_t nelem = out.size();
|
||||
|
||||
// Encode input arrays to kernel
|
||||
set_array_buffer(compute_encoder, x, 0);
|
||||
set_array_buffer(compute_encoder, y, 1);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(y, 1);
|
||||
|
||||
// Encode output arrays to kernel
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
@@ -300,7 +296,7 @@ void Axpby::eval_gpu(
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#else // Metal is not available
|
||||
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -33,7 +33,7 @@ array axpby(
|
||||
class Axpby : public Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta){};
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -19,7 +19,7 @@ template <typename T>
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
@@ -31,30 +31,30 @@ template <typename T>
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||
}
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] \
|
||||
[[kernel]] void axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
|
||||
axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
|
||||
axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
|
@@ -1,31 +1,31 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
@@ -2,4 +2,4 @@
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .mlx_sample_extensions import *
|
||||
from ._ext import axpby
|
||||
|
@@ -1,3 +1,8 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"mlx>=0.9.0",
|
||||
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
4
examples/extensions/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.9.0
|
||||
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
@@ -9,11 +9,11 @@ if __name__ == "__main__":
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_dir={"": "."},
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev": []},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
10
examples/extensions/test.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
@@ -19,11 +19,17 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||
)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
if (MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
else()
|
||||
elseif(MLX_BUILD_CPU)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
|
@@ -14,7 +14,7 @@ class Buffer {
|
||||
void* ptr_;
|
||||
|
||||
public:
|
||||
Buffer(void* ptr) : ptr_(ptr){};
|
||||
Buffer(void* ptr) : ptr_(ptr) {};
|
||||
|
||||
// Get the raw data pointer from the buffer
|
||||
void* raw_ptr();
|
||||
|
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -93,7 +92,13 @@ void array::detach() {
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
// Ensure the array is ready to be read
|
||||
if (status() == Status::scheduled) {
|
||||
event().wait();
|
||||
set_status(Status::available);
|
||||
} else if (status() == Status::unscheduled) {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
@@ -161,6 +166,39 @@ void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::~array() {
|
||||
if (array_desc_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that will be detached
|
||||
if (status() != array::Status::unscheduled) {
|
||||
return;
|
||||
}
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
bool do_detach = true;
|
||||
// If all siblings have siblings.size() references except
|
||||
// the one we are currently destroying (which has siblings.size() + 1)
|
||||
// then there are no more external references
|
||||
do_detach &= (array_desc_.use_count() == (n + 1));
|
||||
for (auto& s : siblings()) {
|
||||
do_detach &= (s.array_desc_.use_count() == n);
|
||||
if (!do_detach) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (do_detach) {
|
||||
for (auto& s : siblings()) {
|
||||
for (auto& ss : s.siblings()) {
|
||||
ss.array_desc_ = nullptr;
|
||||
}
|
||||
s.array_desc_->siblings.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void array::ArrayDesc::init() {
|
||||
strides.resize(shape.size());
|
||||
size = 1;
|
||||
@@ -174,7 +212,7 @@ void array::ArrayDesc::init() {
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype) {
|
||||
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
|
||||
init();
|
||||
}
|
||||
|
||||
@@ -185,11 +223,42 @@ array::ArrayDesc::ArrayDesc(
|
||||
std::vector<array> inputs)
|
||||
: shape(std::move(shape)),
|
||||
dtype(dtype),
|
||||
status(Status::unscheduled),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(std::move(inputs)) {
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destroy their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This recursively calls the destructor and can result in stack overflow, so we
|
||||
// instead put them in a vector and destroy them one at a time resulting in a
|
||||
// max stack depth of 2.
|
||||
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
|
||||
|
||||
for (array& a : inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
|
||||
while (!for_deletion.empty()) {
|
||||
// top is going to be deleted at the end of the block *after* the arrays
|
||||
// with inputs have been moved into the vector
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
|
||||
for (array& a : top->inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
: arr(arr), idx(idx) {
|
||||
if (arr.ndim() == 0) {
|
||||
|
64
mlx/array.h
@@ -9,6 +9,7 @@

#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"

namespace mlx::core {

@@ -113,6 +114,15 @@ class array {
    return array_desc_->strides;
  };

  /**
   * Get the stride of the corresponding dimension.
   *
   * This function supports negative indexing and provides
   * bounds checking. */
  size_t strides(int dim) const {
    return strides().at(dim < 0 ? dim + ndim() : dim);
  };

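The doc comment promises negative indexing plus bounds checking; a small standalone sketch of that normalization (hypothetical helper, not part of the class):

// Minimal sketch (not MLX code) of the negative-index normalization used by
// strides(int dim): negative dims count from the end, and .at() bounds-checks.
#include <stdexcept>
#include <vector>

size_t stride_at(const std::vector<size_t>& strides, int dim) {
  int ndim = static_cast<int>(strides.size());
  return strides.at(dim < 0 ? dim + ndim : dim); // throws std::out_of_range if invalid
}

// stride_at({12, 4, 1}, -1) == 1, stride_at({12, 4, 1}, 0) == 12,
// stride_at({12, 4, 1}, 3) throws.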
  /** Get the array's data type. */
  Dtype dtype() const {
    return array_desc_->dtype;
@@ -199,7 +209,7 @@ class array {
    allocator::Buffer buffer;
    deleter_t d;
    Data(allocator::Buffer buffer, deleter_t d = allocator::free)
        : buffer(buffer), d(d){};
        : buffer(buffer), d(d) {};
    // Not copyable
    Data(const Data& d) = delete;
    Data& operator=(const Data& d) = delete;
@@ -251,22 +261,16 @@ class array {
    return array_desc_->siblings;
  };

  /** The array's siblings. */
  std::vector<array>& siblings() {
    return array_desc_->siblings;
  };

  void set_siblings(std::vector<array> siblings, uint16_t position) {
    array_desc_->siblings = std::move(siblings);
    array_desc_->position = position;
  }

  /** The i-th output of the array's primitive. */
  const array& output(int i) const {
    if (i == array_desc_->position) {
      return *this;
    } else if (i < array_desc_->position) {
      return siblings()[i];
    } else {
      return siblings()[i - 1];
    }
  };
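The position bookkeeping is easier to see outside the class: siblings holds every output except the current one, and position records where the current output sits in the primitive's output order. A hedged sketch with a toy output type (not MLX code):

// Minimal sketch (not MLX code) of the output(i) mapping around `position`.
#include <cassert>
#include <string>
#include <vector>

struct Out {
  std::vector<std::string> siblings; // the other outputs, in order
  int position;                      // index of *this* output
  std::string self;

  const std::string& output(int i) const {
    if (i == position) {
      return self;
    } else if (i < position) {
      return siblings[i];
    } else {
      return siblings[i - 1]; // skip the slot occupied by *this*
    }
  }
};

int main() {
  Out o{{"out0", "out2"}, 1, "out1"};
  assert(o.output(0) == "out0");
  assert(o.output(1) == "out1");
  assert(o.output(2) == "out2");
  return 0;
}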

  /** The outputs of the array's primitive (i.e. this array and
   * its siblings) in the order the primitive expects. */
  std::vector<array> outputs() const {
@@ -315,9 +319,27 @@ class array {
    return static_cast<T*>(array_desc_->data_ptr);
  };

  // Check if the array has been evaluated
  bool is_evaled() const {
    return array_desc_->data != nullptr;
  enum Status { unscheduled, scheduled, available };

  bool is_available() const {
    return status() == Status::available;
  }
  const Status status() const {
    return array_desc_->status;
  }

  void set_status(Status s) const {
    array_desc_->status = s;
  }

  // Get the array's shared event
  Event& event() const {
    return array_desc_->event;
  }

  // Attach an event to a not yet evaluated array
  void attach_event(Event e) const {
    array_desc_->event = std::move(e);
  }

  // Mark the array as a tracer array (true) or not.
@@ -358,6 +380,8 @@ class array {
    array_desc_ = other.array_desc_;
  }

  ~array();

 private:
  // Initialize the array's data
  template <typename It>
@@ -370,6 +394,11 @@ class array {
    Dtype dtype;
    std::shared_ptr<Primitive> primitive;

    Status status;

    // An event on the array used for synchronization
    Event event;

    // Indicates an array is being used in a graph transform
    // and should not be detached from the graph
    bool is_tracer{false};
@@ -404,6 +433,8 @@ class array {
        std::shared_ptr<Primitive> primitive,
        std::vector<array> inputs);

    ~ArrayDesc();

   private:
    // Initialize size, strides, and other metadata
    void init();
@@ -468,10 +499,11 @@ T array::item() const {
  if (size() != 1) {
    throw std::invalid_argument("item can only be called on arrays of size 1.");
  }
  if (!is_evaled()) {
  if (status() == Status::unscheduled) {
    throw std::invalid_argument(
        "item() const can only be called on evaled arrays");
  }
  const_cast<array*>(this)->eval();
  return *data<T>();
}

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>

@@ -196,6 +196,40 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
  return matmul_bnns_general(a_pre, b_pre, out);
}

template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int tile_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  int tX = (X + tile_size - 1) / tile_size;
  int tY = (Y + tile_size - 1) / tile_size;

  for (int i = 0; i < tX; i++) {
    for (int j = 0; j < tY; j++) {
      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
      if (!do_mask) {
        int loc_x = i * tile_size;
        int loc_y = j * tile_size;
        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;

        int size_x = std::min(tile_size, X - loc_x);
        int size_y = std::min(tile_size, Y - loc_y);
        for (int ii = 0; ii < size_x; ii++) {
          for (int jj = 0; jj < size_y; jj++) {
            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
          }
        }
      }
    }
  }
}

} // namespace

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -31,9 +31,11 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
@@ -45,6 +47,8 @@ DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
@@ -76,6 +80,7 @@ DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
@@ -191,6 +196,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
  }
}

void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32 && a.flags().row_contiguous &&
      b.flags().row_contiguous) {
    if (a.is_donatable()) {
      out.copy_shared_buffer(a);
    } else if (b.is_donatable()) {
      out.copy_shared_buffer(b);
    } else {
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
    }
    int size = a.data_size();
    vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}
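The branch above is the usual donation pattern: if an input buffer is donatable and has the right layout, the result is written in place of it; otherwise fresh storage is allocated. A minimal sketch of that decision with hypothetical types (not MLX code):

// Minimal sketch (not MLX code) of output-buffer donation for a binary op:
// prefer reusing an input that will not be needed again.
#include <vector>

struct Buf {
  std::vector<float> data;
  bool donatable; // e.g. last use of a contiguous temporary
};

Buf& pick_output(Buf& a, Buf& b, Buf& scratch) {
  if (a.donatable) {
    return a; // write the result in place of `a`
  } else if (b.donatable) {
    return b; // otherwise in place of `b`
  }
  scratch.data.resize(a.data.size()); // otherwise allocate fresh storage
  return scratch;
}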

void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
@@ -310,6 +335,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
  }
}

void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vvexpm1f(
        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else {
    eval(inputs, out);
  }
}

void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
  }
};

template <typename T, typename VT, typename Ops, int N>
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
  Ops ops;

@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
    VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
    size_t s = M;
    while (s >= N) {
      vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
      VT vals;
      if constexpr (std::is_same<T, AccT>::value) {
        vals = ops.load(current_in_ptr);
      } else {
        for (int i = 0; i < N; ++i) {
          vals[i] = static_cast<AccT>(current_in_ptr[i]);
        }
      }
      vmaximum = ops.max(vals, vmaximum);
      current_in_ptr += N;
      s -= N;
    }
    T maximum = ops.reduce_max(vmaximum);
    AccT maximum = ops.reduce_max(vmaximum);
    while (s-- > 0) {
      maximum = std::max(maximum, *current_in_ptr);
      maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
      current_in_ptr++;
    }

@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
      ops.store(current_out_ptr, vexp);
      *(VT*)current_out_ptr = vexp;
      VT vexp;
      if constexpr (std::is_same<T, AccT>::value) {
        vexp = ops.load(current_in_ptr);
      } else {
        for (int i = 0; i < N; ++i) {
          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
        }
      }
      vexp = ops.exp(ops.sub(vexp, maximum));
      if constexpr (std::is_same<T, AccT>::value) {
        ops.store(current_out_ptr, vexp);
      }
      vnormalizer = ops.add(vnormalizer, vexp);
      current_in_ptr += N;
      current_out_ptr += N;
      s -= N;
    }
    T normalizer = ops.reduce_add(vnormalizer);
    AccT normalizer = ops.reduce_add(vnormalizer);
    while (s-- > 0) {
      T _exp = std::exp(*current_in_ptr - maximum);
      *current_out_ptr = _exp;
      AccT _exp = std::exp(*current_in_ptr - maximum);
      if (std::is_same<T, AccT>::value) {
        *current_out_ptr = _exp;
      }
      normalizer += _exp;
      current_in_ptr++;
      current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {

    // Normalize
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
      if constexpr (std::is_same<T, AccT>::value) {
        ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
      } else {
        VT vexp;
        for (int i = 0; i < N; ++i) {
          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
        }
        vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
        for (int i = 0; i < N; ++i) {
          current_out_ptr[i] = vexp[i];
        }
        current_in_ptr += N;
      }
      current_out_ptr += N;
      s -= N;
    }
    while (s-- > 0) {
      *current_out_ptr *= normalizer;
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr *= normalizer;
      } else {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(_exp * normalizer);
        current_in_ptr++;
      }
      current_out_ptr++;
    }
  }
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
          in, out);
      softmax<
          float,
          float,
          simd_float16,
          AccelerateSimdOps<float, simd_float16>,
          16>(in, out);
      break;
    case float16:
      softmax<
          float16_t,
          float16x8_t,
          NeonFp16SimdOps<float16_t, float16x8_t>,
          8>(in, out);
      if (precise_) {
        softmax<
            float16_t,
            float,
            simd_float16,
            AccelerateSimdOps<float, simd_float16>,
            16>(in, out);
      } else {
        softmax<
            float16_t,
            float16_t,
            float16x8_t,
            NeonFp16SimdOps<float16_t, float16x8_t>,
            8>(in, out);
      }
      break;
    case bfloat16:
      eval(inputs, out);
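The rewrite threads a separate accumulation type AccT through all three passes so that, when precise_ is set, float16 inputs are reduced and exponentiated in float32 and only the final values are cast back. A scalar sketch of the same three-pass structure (a simplification of the vectorized code above, not MLX code):

// Minimal scalar sketch (not MLX code) of numerically stable softmax with a
// separate accumulation type: T is the storage type (e.g. a half float) and
// AccT the accumulator (e.g. float).
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

template <typename T, typename AccT>
void softmax_1d(const T* in, T* out, int n) {
  // Pass 1: maximum, for numerical stability.
  AccT maximum = -std::numeric_limits<AccT>::infinity();
  for (int i = 0; i < n; ++i) {
    maximum = std::max(maximum, static_cast<AccT>(in[i]));
  }
  // Pass 2: exponentials and their sum, accumulated in AccT.
  AccT normalizer = 0;
  std::vector<AccT> exps(n);
  for (int i = 0; i < n; ++i) {
    exps[i] = std::exp(static_cast<AccT>(in[i]) - maximum);
    normalizer += exps[i];
  }
  // Pass 3: normalize and cast back to the storage type.
  normalizer = AccT(1) / normalizer;
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<T>(exps[i] * normalizer);
  }
}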
@@ -37,10 +37,12 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
@@ -54,6 +56,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

@@ -236,4 +236,82 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
  binary(a, b, out, detail::Subtract());
}

void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto dispatch_type = [&a, &b, &out](auto op) {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool>(a, b, out, op);
        break;
      case uint8:
        binary_op<uint8_t>(a, b, out, op);
        break;
      case uint16:
        binary_op<uint16_t>(a, b, out, op);
        break;
      case uint32:
        binary_op<uint32_t>(a, b, out, op);
        break;
      case uint64:
        binary_op<uint64_t>(a, b, out, op);
        break;
      case int8:
        binary_op<int8_t>(a, b, out, op);
        break;
      case int16:
        binary_op<int16_t>(a, b, out, op);
        break;
      case int32:
        binary_op<int32_t>(a, b, out, op);
        break;
      case int64:
        binary_op<int64_t>(a, b, out, op);
        break;
      default:
        throw std::runtime_error(
            "[BitwiseBinary::eval_cpu] Type not supported");
        break;
    }
  };
  switch (op_) {
    case BitwiseBinary::And:
      dispatch_type(detail::BitwiseAnd());
      break;
    case BitwiseBinary::Or:
      dispatch_type(detail::BitwiseOr());
      break;
    case BitwiseBinary::Xor:
      dispatch_type(detail::BitwiseXor());
      break;
    case BitwiseBinary::LeftShift:
      dispatch_type(detail::LeftShift());
      break;
    case BitwiseBinary::RightShift:
      dispatch_type(detail::RightShift());
      break;
  }
}

void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  if (out.dtype() == float32) {
    binary_op<float>(a, b, out, detail::ArcTan2());
  } else if (out.dtype() == float16) {
    binary_op<float16_t>(a, b, out, detail::ArcTan2());
  } else if (out.dtype() == bfloat16) {
    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
  } else if (issubdtype(out.dtype(), inexact)) {
    std::ostringstream err;
    err << "[arctan2] Does not support " << out.dtype();
    throw std::invalid_argument(err.str());
  } else {
    throw std::invalid_argument(
        "[arctan2] Cannot compute inverse tangent for arrays"
        " with non floating point type.");
  }
}

} // namespace mlx::core
@@ -1,6 +1,8 @@
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
mlx/backend/common/cholesky.cpp
@@ -0,0 +1,101 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
  int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info,
      /* uplo_len = */ static_cast<size_t>(1));
#else
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info);
#endif

  return info;
}

} // namespace

void cholesky_impl(const array& a, array& factor, bool upper) {
  // Lapack uses the column-major convention. We take advantage of the fact that
  // the matrix should be symmetric:
  //   (A)ᵀ = A
  // and that a column-major lower triangular matrix is a row-major upper
  // triangular matrix, so uplo is the opposite of what we would expect from
  // upper

  char uplo = (upper) ? 'L' : 'U';

  // The decomposition is computed in place, so just copy the input to the
  // output.
  copy(
      a,
      factor,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  const int N = a.shape(-1);
  const size_t num_matrices = a.size() / (N * N);

  float* matrix = factor.data<float>();

  for (int i = 0; i < num_matrices; i++) {
    // Compute Cholesky factorization.
    int info = spotrf_wrapper(uplo, matrix, N);

    // TODO: We do nothing when the matrix is not positive semi-definite
    // because throwing an error would result in a crash. If we figure out how
    // to catch errors from the implementation we should throw.
    if (info < 0) {
      std::stringstream msg;
      msg << "[cholesky] Cholesky decomposition failed with error code "
          << info;
      throw std::runtime_error(msg.str());
    }

    // Zero out the upper/lower triangle while advancing the pointer to the
    // next matrix at the same time.
    for (int row = 0; row < N; row++) {
      if (upper) {
        std::fill(matrix, matrix + row, 0);
      } else {
        std::fill(matrix + row + 1, matrix + N, 0);
      }
      matrix += N;
    }
  }
}

void Cholesky::eval(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Cholesky::eval] only supports float32.");
  }
  cholesky_impl(inputs[0], output, upper_);
}

} // namespace mlx::core
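The uplo flip rests on a layout fact: the same buffer read row-major is the transpose of that buffer read column-major, and A is symmetric, so asking column-major LAPACK for the 'L' factor produces what the row-major caller sees as the upper factor. A tiny sketch of the layout part of that argument (not MLX code):

// Minimal sketch (not MLX code): one buffer, two layout interpretations.
#include <cstdio>

int main() {
  float buf[4] = {1.f, 2.f, 3.f, 4.f};
  // Row-major view:    [[1, 2],    Column-major view:  [[1, 3],
  //                     [3, 4]]                         [2, 4]]
  // The column-major view is the transpose of the row-major one.
  std::printf("row-major (0,1) = %g, column-major (0,1) = %g\n",
              buf[0 * 2 + 1],  // = 2
              buf[1 * 2 + 0]); // = 3
  return 0;
}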
mlx/backend/common/common.cpp
@@ -0,0 +1,347 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (!in.flags().row_contiguous) {
|
||||
// Just ensuring that inputs[0] came from the ops which would ensure the
|
||||
// input is row contiguous.
|
||||
throw std::runtime_error(
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
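The flag computation in AsStrided::eval above is the generic contiguity test: scanning dimensions from the innermost outwards, each stride must equal the product of the inner dimension sizes (size-1 dimensions are ignored), with the mirror-image scan for column contiguity. A standalone sketch of that check (hypothetical names, not MLX code):

// Minimal sketch (not MLX code) of the row/column contiguity check.
#include <cstddef>
#include <vector>

struct Contiguity {
  bool row;
  bool col;
};

Contiguity check_contiguity(const std::vector<int>& shape,
                            const std::vector<size_t>& strides) {
  bool row = true, col = true;
  size_t r = 1, c = 1;
  int n = static_cast<int>(shape.size());
  for (int i = n - 1, j = 0; i >= 0; --i, ++j) {
    row &= (strides[i] == r) || (shape[i] == 1);
    col &= (strides[j] == c) || (shape[j] == 1);
    r *= shape[i];
    c *= shape[j];
  }
  return {row, col};
}

// e.g. shape {2, 3}, strides {3, 1} -> row-contiguous;
//      shape {2, 3}, strides {1, 2} -> column-contiguous.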
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomVJP::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
void Depends::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
numel *= inputs[0].shape(ax);
|
||||
}
|
||||
|
||||
if (inverted_) {
|
||||
numel = 1.0 / numel;
|
||||
}
|
||||
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
*out.data<bool>() = static_cast<bool>(numel);
|
||||
break;
|
||||
case uint8:
|
||||
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
|
||||
break;
|
||||
case uint16:
|
||||
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
|
||||
break;
|
||||
case uint32:
|
||||
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
|
||||
break;
|
||||
case uint64:
|
||||
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
|
||||
break;
|
||||
case int8:
|
||||
*out.data<int8_t>() = static_cast<int8_t>(numel);
|
||||
break;
|
||||
case int16:
|
||||
*out.data<int16_t>() = static_cast<int16_t>(numel);
|
||||
break;
|
||||
case int32:
|
||||
*out.data<int32_t>() = static_cast<int32_t>(numel);
|
||||
break;
|
||||
case int64:
|
||||
*out.data<int64_t>() = static_cast<int64_t>(numel);
|
||||
break;
|
||||
case float16:
|
||||
*out.data<float16_t>() = static_cast<float16_t>(numel);
|
||||
break;
|
||||
case float32:
|
||||
*out.data<float>() = static_cast<float>(numel);
|
||||
break;
|
||||
case bfloat16:
|
||||
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
|
||||
break;
|
||||
case complex64:
|
||||
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
}
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
auto [shape, _strides] = collapse_contiguous_dims(in);
|
||||
auto& strides = _strides[0];
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
int N = out.shape(i);
|
||||
if (j < shape.size() && shape[j] % N == 0) {
|
||||
shape[j] /= N;
|
||||
out_strides.push_back(shape[j] * strides[j]);
|
||||
j += (shape[j] == 1);
|
||||
} else if (N == 1) {
|
||||
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
|
||||
out_strides.push_back(out_strides.back());
|
||||
} else {
|
||||
copy_necessary = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
// For row contiguous reshapes:
|
||||
// - Shallow copy the buffer
|
||||
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||
// becomes col contiguous again.
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
auto compute_new_flags = [](const auto& shape,
|
||||
const auto& strides,
|
||||
size_t in_data_size,
|
||||
auto flags) {
|
||||
size_t data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.row_contiguous = true;
|
||||
flags.col_contiguous = true;
|
||||
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
|
||||
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
f_stride *= shape[i];
|
||||
b_stride *= shape[ri];
|
||||
if (strides[i] > 0) {
|
||||
data_size *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in_data_size) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
return std::pair<decltype(flags), size_t>{flags, data_size};
|
||||
};
|
||||
|
||||
std::vector<int> indices(1, 0);
|
||||
indices.insert(indices.end(), indices_.begin(), indices_.end());
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
size_t offset = indices[i] * in.strides()[axis_];
|
||||
auto [new_flags, data_size] = compute_new_flags(
|
||||
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
|
||||
outputs[i].copy_shared_buffer(
|
||||
in, in.strides(), new_flags, data_size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
}
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -126,4 +126,102 @@ std::string build_lib_name(
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape) {
|
||||
bool contiguous = true;
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (const auto& x : inputs) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 0 && !shape.empty()) {
|
||||
contiguous = false;
|
||||
}
|
||||
return contiguous;
|
||||
}
|
||||
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers /* = false */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Correct size
|
||||
// - Not a scalar
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
} else {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
} else {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
}
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
  return x.ndim() == 0;
}

// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const std::vector<int>& shape);

// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers = false);

} // namespace mlx::core
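The declaration above centralizes the "can this compiled kernel run as one flat loop" decision that Compiled::eval_cpu previously computed inline. A hedged sketch of the rule as it appears in the diff (simplified inputs, not MLX code):

// Minimal sketch (not MLX code) of the contiguity rule: scalars are skipped;
// with more than one non-scalar input they must all be row-contiguous or all
// column-contiguous with the output shape; a single non-scalar input only
// needs to be contiguous with the output shape. (The compiled.cpp version in
// the diff additionally treats an all-scalar input set with a non-empty
// output shape as non-contiguous.)
#include <vector>

struct In {
  bool scalar;
  bool contiguous;
  bool row_contiguous;
  bool col_contiguous;
  bool shape_matches_output;
};

bool can_run_contiguous(const std::vector<In>& inputs) {
  bool all_contig = true, all_row = true, all_col = true;
  int non_scalar = 0;
  for (const auto& x : inputs) {
    if (x.scalar) {
      continue;
    }
    non_scalar++;
    all_contig &= x.contiguous && x.shape_matches_output;
    all_row &= x.row_contiguous && x.shape_matches_output;
    all_col &= x.col_contiguous && x.shape_matches_output;
  }
  if (non_scalar > 1) {
    return all_row || all_col;
  }
  if (non_scalar == 1) {
    return all_contig;
  }
  return true; // only scalar inputs
}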
@@ -52,8 +52,25 @@ void* compile(
    return nullptr;
  }

  std::string kernel_file_name;

  // Deal with long kernel names. Maximum length for files on macOS is 255
  // characters. Clip file name with a little extra room and append a 16
  // character hash.
  constexpr int max_file_name_length = 245;
  if (kernel_name.size() > max_file_name_length) {
    std::ostringstream file_name;
    file_name
        << std::string_view(kernel_name).substr(0, max_file_name_length - 16);
    auto file_id = std::hash<std::string>{}(kernel_name);
    file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
    kernel_file_name = file_name.str();
  } else {
    kernel_file_name = kernel_name;
  }

  std::ostringstream shared_lib_name;
  shared_lib_name << "lib" << kernel_name << ".so";
  shared_lib_name << "lib" << kernel_file_name << ".so";
  auto shared_lib_path = get_temp_file(shared_lib_name.str());
  bool lib_exists = false;
  {
@@ -64,7 +81,7 @@ void* compile(
  if (!lib_exists) {
    // Open source file and write source code to it
    std::ostringstream source_file_name;
    source_file_name << kernel_name << ".cpp";
    source_file_name << kernel_file_name << ".cpp";
    auto source_file_path = get_temp_file(source_file_name.str());

    std::ofstream source_file(source_file_path);
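The block above keeps kernel cache files under the macOS 255-character file-name limit by clipping the name and appending a hash of the full name, so distinct long kernels still map to distinct files. A standalone sketch of that scheme (hypothetical helper, not MLX code):

// Minimal sketch (not MLX code): keep a prefix of an over-long name and append
// a fixed-width hash of the full name so truncated names cannot collide.
#include <functional>
#include <iomanip>
#include <sstream>
#include <string>

std::string clip_with_hash(const std::string& name, size_t max_len = 245) {
  if (name.size() <= max_len) {
    return name;
  }
  std::ostringstream out;
  out << name.substr(0, max_len - 16);
  out << "_" << std::hex << std::setw(16) << std::setfill('0')
      << std::hash<std::string>{}(name);
  return out.str();
}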
@@ -248,28 +265,7 @@ void Compiled::eval_cpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
{
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (auto& x : inputs) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
}
|
||||
}
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
@@ -342,58 +338,8 @@ void Compiled::eval_cpu(
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
|
@@ -38,11 +38,15 @@ void slow_conv_1D(
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int C = in.shape(2); // Input channels
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(2); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
|
||||
const int groups = C / wt.shape(2);
|
||||
const int C_per_group = wt.shape(2);
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_C = in.strides()[2];
|
||||
@@ -57,35 +61,36 @@ void slow_conv_1D(
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int oh = 0; oh < oH; ++oh) {
|
||||
for (int o = 0; o < O; ++o) {
|
||||
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||
|
||||
int wh_flip = flip ? (wH - wh - 1) : wh;
|
||||
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
|
||||
int wh_flip = flip ? (wH - wh - 1) : wh;
|
||||
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
|
||||
|
||||
auto ih_div = std::div(ih, in_dilation[0]);
|
||||
auto ih_div = std::div(ih, in_dilation[0]);
|
||||
|
||||
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(
|
||||
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
|
||||
static_cast<float>(wt_ptr[c * wt_stride_C]);
|
||||
} // c
|
||||
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
r += static_cast<float>(
|
||||
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
|
||||
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih check
|
||||
} // wh
|
||||
} // ih check
|
||||
} // wh
|
||||
|
||||
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
|
||||
} // o
|
||||
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
|
||||
} // o
|
||||
} // g
|
||||
} // oh
|
||||
|
||||
in_ptr += in_stride_N;
|
||||
out_ptr += out_stride_N;
|
||||
|
||||
} // n
|
||||
}
|
||||
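The grouped-convolution change above partitions channels and filters: output channel o belongs to group g = o / O_per_group, reads only input channels [g * C_per_group, (g + 1) * C_per_group), and indexes its weights by c % C_per_group. A hedged sketch of that index mapping for a stride-1, unpadded 1-D case (not MLX code):

// Minimal sketch (not MLX code) of grouped 1-D convolution indexing.
#include <vector>

std::vector<float> grouped_conv1d_valid(
    const std::vector<float>& in,  // [iH][C], row-major
    const std::vector<float>& wt,  // [O][wH][C_per_group], row-major
    int iH, int C, int O, int wH, int groups) {
  int C_per_group = C / groups;
  int O_per_group = O / groups;
  int oH = iH - wH + 1; // stride 1, no padding, no flip
  std::vector<float> out(oH * O, 0.f);
  for (int oh = 0; oh < oH; ++oh) {
    for (int o = 0; o < O; ++o) {
      int g = o / O_per_group; // the group this filter belongs to
      float r = 0.f;
      for (int wh = 0; wh < wH; ++wh) {
        for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
          r += in[(oh + wh) * C + c] *
               wt[(o * wH + wh) * C_per_group + (c % C_per_group)];
        }
      }
      out[oh * O + o] = r;
    }
  }
  return out;
}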
|
||||
@@ -106,13 +111,17 @@ void slow_conv_2D(
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int C = in.shape(3); // In channels
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(3); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
const int wW = wt.shape(2); // Weight spatial dim
|
||||
|
||||
const int groups = C / wt.shape(3);
|
||||
const int C_per_group = wt.shape(3);
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_W = in.strides()[2];
|
||||
@@ -136,33 +145,35 @@ void slow_conv_2D(
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
@@ -214,41 +225,43 @@ void slow_conv_2D(
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
@@ -305,6 +318,296 @@ void slow_conv_2D(
|
||||
} // n
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void slow_conv_3D(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const T* st_wt_ptr = wt.data<T>();
|
||||
const T* st_in_ptr = in.data<T>();
|
||||
T* st_out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
|
||||
const int oD = out.shape(1); // Output spatial dim
|
||||
const int oH = out.shape(2); // Output spatial dim
|
||||
const int oW = out.shape(3); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(4); // In channels
|
||||
const int wD = wt.shape(1); // Weight spatial dim
|
||||
const int wH = wt.shape(2); // Weight spatial dim
|
||||
const int wW = wt.shape(3); // Weight spatial dim
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_D = in.strides()[1];
|
||||
const size_t in_stride_H = in.strides()[2];
|
||||
const size_t in_stride_W = in.strides()[3];
|
||||
const size_t in_stride_C = in.strides()[4];
|
||||
|
||||
const size_t wt_stride_O = wt.strides()[0];
|
||||
const size_t wt_stride_D = wt.strides()[1];
|
||||
const size_t wt_stride_H = wt.strides()[2];
|
||||
const size_t wt_stride_W = wt.strides()[3];
|
||||
const size_t wt_stride_C = wt.strides()[4];
|
||||
|
||||
const size_t out_stride_N = out.strides()[0];
|
||||
const size_t out_stride_D = out.strides()[1];
|
||||
const size_t out_stride_H = out.strides()[2];
|
||||
const size_t out_stride_W = out.strides()[3];
|
||||
const size_t out_stride_O = out.strides()[4];
|
||||
|
||||
bool is_idil_one =
|
||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
|
||||
|
||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
||||
const T* wt_ptr,
|
||||
T* out_ptr,
|
||||
int od,
|
||||
int oh,
|
||||
int ow) {
|
||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||
int id_base = od * wt_strides[0] - padding[0];
|
||||
int ih_base = oh * wt_strides[1] - padding[1];
|
||||
int iw_base = ow * wt_strides[2] - padding[2];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wd = 0; wd < wD; ++wd) {
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wd_flip = flip ? wD - wd - 1 : wd;
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int id = id_base + wd_flip * wt_dilation[0];
|
||||
int ih = ih_base + wh_flip * wt_dilation[1];
|
||||
int iw = iw_base + ww_flip * wt_dilation[2];
|
||||
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
} // wd
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
};
|
||||
|
||||
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
|
||||
|
||||
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
|
||||
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
|
||||
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
|
||||
|
||||
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
|
||||
|
||||
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
|
||||
|
||||
std::vector<int> base_d(f_out_jump_d);
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
  std::vector<int> base_w(f_out_jump_w);

  for (int i = 0; i < f_out_jump_d; ++i) {
    int id_loop = i * wt_strides[0] - padding[0] + init_d;

    int wd_base = 0;
    while (wd_base < wD && id_loop % in_dilation[0] != 0) {
      wd_base++;
      id_loop += jump_d;
    }

    base_d[i] = wd_base;
  }

  for (int i = 0; i < f_out_jump_h; ++i) {
    int ih_loop = i * wt_strides[1] - padding[1] + init_h;

    int wh_base = 0;
    while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
      wh_base++;
      ih_loop += jump_h;
    }

    base_h[i] = wh_base;
  }

  for (int j = 0; j < f_out_jump_w; ++j) {
    int iw_loop = j * wt_strides[2] - padding[2] + init_w;

    int ww_base = 0;
    while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
      ww_base++;
      iw_loop += jump_w;
    }

    base_w[j] = ww_base;
  }

  auto pt_conv_all_checks = [&](const T* in_ptr,
                                const T* wt_ptr,
                                T* out_ptr,
                                int od,
                                int oh,
                                int ow) {
    out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;

    int id_base = od * wt_strides[0] - padding[0];
    int ih_base = oh * wt_strides[1] - padding[1];
    int iw_base = ow * wt_strides[2] - padding[2];

    int wd_base = base_d[od % f_out_jump_d];
    int wh_base = base_h[oh % f_out_jump_h];
    int ww_base = base_w[ow % f_out_jump_w];

    for (int o = 0; o < O; ++o) {
      float r = 0.;

      for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
        for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
          for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
            int wd_flip = flip ? wD - wd - 1 : wd;
            int wh_flip = flip ? wH - wh - 1 : wh;
            int ww_flip = flip ? wW - ww - 1 : ww;
            int id = id_base + wd_flip * wt_dilation[0];
            int ih = ih_base + wh_flip * wt_dilation[1];
            int iw = iw_base + ww_flip * wt_dilation[2];

            if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
                iw < iW) {
              const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
                  wh * wt_stride_H + ww * wt_stride_W;

              int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
              int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
              int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;

              const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
                  ih_dil * in_stride_H + iw_dil * in_stride_W;

              for (int c = 0; c < C; ++c) {
                r += static_cast<float>(in_ptr_pt[0]) *
                    static_cast<float>(wt_ptr_pt[0]);
                in_ptr_pt += in_stride_C;
                wt_ptr_pt += wt_stride_C;
              } // c

            } // id, ih, iw check
          } // ww
        } // wh
      } // wd

      out_ptr[0] = static_cast<T>(r);
      out_ptr += out_stride_O;
      wt_ptr += wt_stride_O;
    } // o
  };

  int oD_border_0 = 0;
  int oD_border_1 =
      is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
  int oD_border_2 = std::max(
      oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
  int oD_border_3 = oD;

  int oH_border_0 = 0;
  int oH_border_1 =
      is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
  int oH_border_2 = std::max(
      oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
  int oH_border_3 = oH;

  int oW_border_0 = 0;
  int oW_border_1 =
      is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
  int oW_border_2 = std::max(
      oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
  int oW_border_3 = oW;

  for (int n = 0; n < N; ++n) {
    // Case 1: od might put us out of bounds
    for (int od = oD_border_0; od < oD_border_1; ++od) {
      for (int oh = 0; oh < oH; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh
    } // od

    // Case 2: od in bounds
    for (int od = oD_border_1; od < oD_border_2; ++od) {
      // Case 2.1: oh might put us out of bounds
      for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh

      // Case 2.2: oh in bounds
      for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
        // Case 2.2.1: ow might put us out of bounds
        for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow

        // Case 2.2.2: ow in bounds
        for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
          pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow

        // Case 2.2.3: ow might put us out of bounds
        for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh

      // Case 2.3: oh might put us out of bounds
      for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh
    } // od

    // Case 3: od might put us out of bounds
    for (int od = oD_border_2; od < oD_border_3; ++od) {
      for (int oh = 0; oh < oH; ++oh) {
        for (int ow = 0; ow < oW; ++ow) {
          pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
        } // ow
      } // oh
    } // od

    st_in_ptr += in_stride_N;
    st_out_ptr += out_stride_N;

  } // n
}

void dispatch_slow_conv_1D(
    const array& in,
    const array& wt,
@@ -353,6 +656,30 @@ void dispatch_slow_conv_2D(
  }
}

void dispatch_slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (in.dtype() == float32) {
    return slow_conv_3D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == float16) {
    return slow_conv_3D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_3D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
  }
}

///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////

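For reference, the explicit GEMM path in the hunks that follow lowers a convolution to one (or, with groups, a few) sgemm calls over an im2col-style strided view of the padded input. Below is a minimal standalone sketch of that idea, not MLX code: it assumes a single NHC float32 example, unit dilation, no padding and no groups, and that a CBLAS implementation is linked; all names are illustrative.

#include <cblas.h>
#include <vector>

// 1D convolution via im2col + GEMM.
// in is [iH, C], wt is [O, wH, C] (both row major), out becomes [oH, O].
void conv1d_im2col_sgemm(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    std::vector<float>& out,
    int iH, int C, int O, int wH, int stride) {
  int oH = (iH - wH) / stride + 1;

  // im2col: each output position becomes a row of length wH * C.
  std::vector<float> cols(static_cast<size_t>(oH) * wH * C);
  for (int oh = 0; oh < oH; ++oh) {
    for (int w = 0; w < wH; ++w) {
      for (int c = 0; c < C; ++c) {
        cols[(oh * wH + w) * C + c] = in[(oh * stride + w) * C + c];
      }
    }
  }

  // One GEMM: [oH, wH*C] x [O, wH*C]^T -> [oH, O].
  out.assign(static_cast<size_t>(oH) * O, 0.f);
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // A = cols
      CblasTrans, // B = wt, used transposed
      oH, // M
      O, // N
      wH * C, // K
      1.0f, // alpha
      cols.data(),
      wH * C, // lda
      wt.data(),
      wH * C, // ldb
      0.0f, // beta
      out.data(),
      O); // ldc
}
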
@@ -366,11 +693,15 @@ void explicit_gemm_conv_1D_cpu(
    const std::vector<int>& wt_dilation) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int C = in.shape(2); // Input channels
  const int oH = out.shape(1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(2); // In channels
  const int wH = wt.shape(1); // Weight spatial dim

  const int groups = C / wt.shape(2);
  const int C_per_group = wt.shape(2);
  const int O_per_group = O / groups;

  auto conv_dtype = float32;

  // Pad input
@@ -402,6 +733,11 @@ void explicit_gemm_conv_1D_cpu(
      in_padded.strides()[1],
      in_padded.strides()[2]};
  auto flags = in_padded.flags();
  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    std::swap(strided_shape[2], strided_shape[3]);
    std::swap(strided_strides[2], strided_strides[3]);
  }

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
@@ -416,7 +752,19 @@ void explicit_gemm_conv_1D_cpu(
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    array wt_transpose(
        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
    wt_transpose.copy_shared_buffer(
        wt,
        {wt.strides(0), wt.strides(2), wt.strides(1)},
        wt.flags(),
        wt.size(),
        0);
    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
    copy(wt_transpose, gemm_wt, CopyType::General);
  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -428,27 +776,29 @@ void explicit_gemm_conv_1D_cpu(
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
  }

  // Perform gemm
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // no trans A
      CblasTrans, // transB
      strided_reshape[0], // M
      O, // N
      strided_reshape[1], // K
      1.0f, // alpha
      in_strided.data<float>(),
      strided_reshape[1], // lda
      gemm_wt.data<float>(),
      strided_reshape[1], // ldb
      0.0f, // beta
      gemm_out.data<float>(),
      O // ldc
  );
  for (int g = 0; g < groups; ++g) {
    // Perform gemm
    cblas_sgemm(
        CblasRowMajor,
        CblasNoTrans, // no trans A
        CblasTrans, // transB
        strided_reshape[0], // M
        O_per_group, // N
        C_per_group * wH, // K
        1.0f, // alpha
        in_strided.data<float>() + g * C_per_group * wH, // A
        wH * C, // lda
        gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
        wH * C_per_group, // ldb
        0.0f, // beta
        gemm_out.data<float>() + g * O_per_group, // C
        O // ldc
    );

    // Copy results if needed
    if (out.dtype() != float32) {
      copy(gemm_out, out, CopyType::Vector);
    // Copy results if needed
    if (out.dtype() != float32) {
      copy(gemm_out, out, CopyType::Vector);
    }
  }
}

@@ -554,6 +904,131 @@ void explicit_gemm_conv_2D_cpu(
|
||||
}
|
||||
}
|
||||
|
||||
void explicit_gemm_conv_ND_cpu(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const auto iDim = std::vector<int>(
|
||||
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
|
||||
const auto oDim = std::vector<int>(
|
||||
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(-1); // In channels
|
||||
const auto wDim = std::vector<int>(
|
||||
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
||||
}
|
||||
padded_shape.back() = C;
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = 0;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
||||
}
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
|
||||
strided_shape.front() = N;
|
||||
for (size_t i = 0; i < oDim.size(); i++) {
|
||||
strided_shape[i + 1] = oDim[i];
|
||||
}
|
||||
for (size_t i = 0; i < wDim.size(); i++) {
|
||||
strided_shape[i + 1 + oDim.size()] = wDim[i];
|
||||
}
|
||||
strided_shape.back() = C;
|
||||
|
||||
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
|
||||
strided_strides[0] = in_padded.strides()[0];
|
||||
for (size_t i = 0; i < wt_strides.size(); i++) {
|
||||
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
|
||||
}
|
||||
for (size_t i = 1; i < in_padded.strides().size(); i++) {
|
||||
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
|
||||
}
|
||||
|
||||
auto flags = in_padded.flags();
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N, C};
|
||||
for (const auto& o : oDim) {
|
||||
strided_reshape[0] *= o;
|
||||
}
|
||||
for (const auto& w : wDim) {
|
||||
strided_reshape[1] *= w;
|
||||
}
|
||||
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
// Check wt dtype and prepare
|
||||
auto gemm_wt = wt;
|
||||
auto gemm_out = out;
|
||||
|
||||
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
|
||||
auto ctype =
|
||||
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
gemm_wt = array(wt.shape(), float32, nullptr, {});
|
||||
copy(wt, gemm_wt, ctype);
|
||||
}
|
||||
|
||||
if (out.dtype() != float32) {
|
||||
gemm_out = array(out.shape(), float32, nullptr, {});
|
||||
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
|
||||
}
|
||||
|
||||
// Perform gemm
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // no trans A
|
||||
CblasTrans, // transB
|
||||
strided_reshape[0], // M
|
||||
O, // N
|
||||
strided_reshape[1], // K
|
||||
1.0f, // alpha
|
||||
in_strided.data<float>(),
|
||||
strided_reshape[1], // lda
|
||||
gemm_wt.data<float>(),
|
||||
strided_reshape[1], // ldb
|
||||
0.0f, // beta
|
||||
gemm_out.data<float>(),
|
||||
O // ldc
|
||||
);
|
||||
|
||||
// Copy results if needed
|
||||
if (out.dtype() != float32) {
|
||||
copy(gemm_out, out, CopyType::Vector);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
@@ -589,6 +1064,19 @@ void conv_2D_cpu(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

void conv_3D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  return dispatch_slow_conv_3D(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

} // namespace

void Convolution::eval(const std::vector<array>& inputs, array& out) {
@@ -597,8 +1085,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  auto& wt = inputs[1];

  // 3D convolution
  if (in.ndim() == (3 + 2)) {
    return conv_3D_cpu(
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_);
  }
  // 2D convolution
  if (in.ndim() == (2 + 2)) {
  else if (in.ndim() == (2 + 2)) {
    return conv_2D_cpu(
        in,
        wt,

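As a side note to the routing above: the spatial rank is in.ndim() - 2 (batch and channel excluded), and each output spatial extent used by the border computations follows the usual convolution size formula. A hedged helper, for illustration only and not part of the MLX API:

// Output extent of one spatial axis for a strided, padded, dilated convolution.
inline int conv_out_size(int in_size, int kernel, int stride, int pad, int dilation) {
  return (in_size + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}
// e.g. conv_out_size(/*in_size=*/32, /*kernel=*/3, /*stride=*/1, /*pad=*/1, /*dilation=*/1) == 32
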
@@ -256,7 +256,7 @@ void copy_general_general(
|
||||
}
|
||||
|
||||
int size = std::accumulate(
|
||||
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
|
||||
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < src.size(); i += size) {
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
|
||||
@@ -272,7 +272,7 @@ inline void copy_general_general(const array& src, array& dst) {
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
@@ -281,54 +281,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst, args...);
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst, args...);
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype, args...);
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype, args...);
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, args...);
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -338,46 +338,46 @@ inline void copy_inplace_dispatch(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Args... args) {
|
||||
Args&&... args) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype, args...);
|
||||
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype, args...);
|
||||
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype, args...);
|
||||
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype, args...);
|
||||
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype, args...);
|
||||
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype, args...);
|
||||
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype, args...);
|
||||
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype, args...);
|
||||
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype, args...);
|
||||
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype, args...);
|
||||
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype, args...);
|
||||
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype, args...);
|
||||
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype, args...);
|
||||
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
@@ -41,9 +42,13 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
@@ -57,6 +62,7 @@ DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
@@ -107,6 +113,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

namespace {

@@ -2,7 +2,6 @@

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
@@ -93,12 +92,4 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
  inverse_impl(inputs[0], output);
}

std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0] >= 0 ? 0 : -1;
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
  return {{linalg::inv(a, stream())}, {ax}};
}

} // namespace mlx::core

@@ -11,7 +11,7 @@ GCC=$2
SRCDIR=$3
CLANG=$4

if [ $CLANG = "TRUE" ]; then
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>

305 mlx/backend/common/masked_mm.cpp Normal file
@@ -0,0 +1,305 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename mask_t>
|
||||
inline void mask_matrix(
|
||||
T* data,
|
||||
const mask_t* mask,
|
||||
int block_size,
|
||||
const int X,
|
||||
const int Y,
|
||||
const size_t X_data_str,
|
||||
const size_t Y_data_str,
|
||||
const size_t X_mask_str,
|
||||
const size_t Y_mask_str,
|
||||
const size_t mask_offset) {
|
||||
int tX = (X + block_size - 1) / block_size;
|
||||
int tY = (Y + block_size - 1) / block_size;
|
||||
|
||||
for (int i = 0; i < tX; i++) {
|
||||
for (int j = 0; j < tY; j++) {
|
||||
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
|
||||
if (do_mask != 1) {
|
||||
int loc_x = i * block_size;
|
||||
int loc_y = j * block_size;
|
||||
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
|
||||
|
||||
int size_x = std::min(block_size, X - loc_x);
|
||||
int size_y = std::min(block_size, Y - loc_y);
|
||||
for (int ii = 0; ii < size_x; ii++) {
|
||||
for (int jj = 0; jj < size_y; jj++) {
|
||||
if constexpr (std::is_same_v<mask_t, bool>) {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
|
||||
} else {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
auto check_transpose =
|
||||
[](const array& arr, bool do_copy, bool expand_all = false) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(true, sty, arr_copy);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
bool has_op_mask = inputs.size() > 3;
|
||||
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
|
||||
auto [a_transposed, lda, a] =
|
||||
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
auto [b_transposed, ldb, b] =
|
||||
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_array = [](const array& mask,
|
||||
float* data,
|
||||
int block_size,
|
||||
int batch_idx,
|
||||
int X,
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str) {
|
||||
size_t mask_offset = elem_to_loc(
|
||||
mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
|
||||
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
|
||||
if (mask.dtype() == bool_) {
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask.data<bool>(),
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
X_data_str,
|
||||
Y_data_str,
|
||||
X_mask_str,
|
||||
Y_mask_str,
|
||||
mask_offset);
|
||||
} else {
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask.data<float>(),
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
X_data_str,
|
||||
Y_data_str,
|
||||
X_mask_str,
|
||||
Y_mask_str,
|
||||
mask_offset);
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
|
||||
// Adjust pointer
|
||||
float* ai =
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
float* bi =
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
float* ci = out.data<float>() + M * N * i;
|
||||
|
||||
// Zero out blocks in a and b if needed
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[inputs.size() - 2];
|
||||
mask_array(
|
||||
a_mask,
|
||||
ai,
|
||||
block_size_,
|
||||
i,
|
||||
M,
|
||||
K,
|
||||
a_transposed ? 1 : lda,
|
||||
a_transposed ? lda : 1);
|
||||
|
||||
auto& b_mask = inputs[inputs.size() - 1];
|
||||
mask_array(
|
||||
b_mask,
|
||||
bi,
|
||||
block_size_,
|
||||
i,
|
||||
K,
|
||||
N,
|
||||
b_transposed ? 1 : ldb,
|
||||
b_transposed ? ldb : 1);
|
||||
}
|
||||
|
||||
// Do matmul
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1.0, // alpha
|
||||
ai,
|
||||
lda,
|
||||
bi,
|
||||
ldb,
|
||||
0.0, // beta
|
||||
ci,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
|
||||
// Zero out blocks in out
|
||||
if (has_out_mask) {
|
||||
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
auto check_transpose = [](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
// Get batch dims
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
size_t matrix_stride_out = M * N;
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
};
|
||||
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
std::vector<int> batch_shape = get_batch_dims(out.shape());
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
|
||||
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
|
||||
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
|
||||
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
|
||||
|
||||
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < batch_size_out; i++) {
|
||||
// Get index
|
||||
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
|
||||
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
|
||||
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1.0f, // alpha
|
||||
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
|
||||
ldb,
|
||||
0.0f, // beta
|
||||
out.data<float>() + matrix_stride_out * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -161,6 +161,13 @@ struct ArcTan {
  };
};

struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    return std::atan2(y, x);
  };
};

struct ArcTanh {
  template <typename T>
  T operator()(T x) {
@@ -202,6 +209,12 @@ struct Ceil {
  };
};

struct Conjugate {
  complex64_t operator()(complex64_t x) {
    return std::conj(x);
  }
};

struct Cos {
  template <typename T>
  T operator()(T x) {
@@ -241,6 +254,13 @@ struct Exp {
  }
};

struct Expm1 {
  template <typename T>
  T operator()(T x) {
    return expm1(x);
  };
};

struct Floor {
  template <typename T>
  T operator()(T x) {
@@ -599,4 +619,39 @@ struct Select {
  }
};

struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x & y;
  };
};

struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    return x | y;
  };
};

struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  };
};

struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    return x << y;
  };
};

struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    return x >> y;
  };
};

} // namespace mlx::core::detail

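These functor structs are consumed by the element-wise kernels; roughly, a binary kernel walks both inputs and applies the functor per element. A minimal sketch of that pattern, with placeholder names rather than the actual MLX helpers:

#include <cstdint>
#include <vector>

// Apply a binary functor (e.g. detail::BitwiseAnd) element-wise.
template <typename T, typename Op>
std::vector<T> apply_binary(const std::vector<T>& a, const std::vector<T>& b, Op op) {
  std::vector<T> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = op(a[i], b[i]); // BitwiseAnd{}(x, y) returns x & y
  }
  return out;
}

// Usage sketch: auto r = apply_binary<uint32_t>(x, y, mlx::core::detail::BitwiseAnd{});
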
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -113,61 +113,6 @@ void AsType::eval(const std::vector<array>& inputs, array& out) {
|
||||
copy(in, out, ctype);
|
||||
}
|
||||
|
||||
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (!in.flags().row_contiguous) {
|
||||
// Just ensuring that inputs[0] came from the ops which would ensure the
|
||||
// input is row contiguous.
|
||||
throw std::runtime_error(
|
||||
"AsStrided must be used with row contiguous arrays only.");
|
||||
}
|
||||
|
||||
// Compute the flags given the shape and strides
|
||||
bool row_contiguous = true, col_contiguous = true;
|
||||
size_t r = 1, c = 1;
|
||||
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
|
||||
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
|
||||
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
|
||||
r *= shape_[i];
|
||||
c *= shape_[j];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
// TODO: Compute the contiguous flag in a better way cause now we are
|
||||
// unnecessarily strict.
|
||||
flags.contiguous = row_contiguous || col_contiguous;
|
||||
flags.row_contiguous = row_contiguous;
|
||||
flags.col_contiguous = col_contiguous;
|
||||
|
||||
// There is no easy way to compute the actual data size so we use out.size().
|
||||
// The contiguous flag will almost certainly not be set so no code should
|
||||
// rely on data_size anyway.
|
||||
size_t data_size = out.size();
|
||||
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -203,9 +148,15 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == complex64) {
|
||||
unary_fp(in, out, detail::Conjugate());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[conjugate] conjugate must be called on complex input.");
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -232,81 +183,6 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void CustomVJP::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
void Depends::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
numel *= inputs[0].shape(ax);
|
||||
}
|
||||
|
||||
if (inverted_) {
|
||||
numel = 1.0 / numel;
|
||||
}
|
||||
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
*out.data<bool>() = static_cast<bool>(numel);
|
||||
break;
|
||||
case uint8:
|
||||
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
|
||||
break;
|
||||
case uint16:
|
||||
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
|
||||
break;
|
||||
case uint32:
|
||||
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
|
||||
break;
|
||||
case uint64:
|
||||
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
|
||||
break;
|
||||
case int8:
|
||||
*out.data<int8_t>() = static_cast<int8_t>(numel);
|
||||
break;
|
||||
case int16:
|
||||
*out.data<int16_t>() = static_cast<int16_t>(numel);
|
||||
break;
|
||||
case int32:
|
||||
*out.data<int32_t>() = static_cast<int32_t>(numel);
|
||||
break;
|
||||
case int64:
|
||||
*out.data<int64_t>() = static_cast<int64_t>(numel);
|
||||
break;
|
||||
case float16:
|
||||
*out.data<float16_t>() = static_cast<float16_t>(numel);
|
||||
break;
|
||||
case float32:
|
||||
*out.data<float>() = static_cast<float>(numel);
|
||||
break;
|
||||
case bfloat16:
|
||||
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
|
||||
break;
|
||||
case complex64:
|
||||
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -359,6 +235,18 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[expm1] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -524,63 +412,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
}
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
auto [shape, _strides] = collapse_contiguous_dims(in);
|
||||
auto& strides = _strides[0];
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
int N = out.shape(i);
|
||||
if (j < shape.size() && shape[j] % N == 0) {
|
||||
shape[j] /= N;
|
||||
out_strides.push_back(shape[j] * strides[j]);
|
||||
j += (shape[j] == 1);
|
||||
} else if (N == 1) {
|
||||
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
|
||||
out_strides.push_back(out_strides.back());
|
||||
} else {
|
||||
copy_necessary = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
// For row contiguous reshapes:
|
||||
// - Shallow copy the buffer
|
||||
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||
// becomes col contiguous again.
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -651,49 +482,6 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
@@ -725,18 +513,6 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
}
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
@@ -774,58 +550,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
auto compute_new_flags = [](const auto& shape,
|
||||
const auto& strides,
|
||||
size_t in_data_size,
|
||||
auto flags) {
|
||||
size_t data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.row_contiguous = true;
|
||||
flags.col_contiguous = true;
|
||||
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
|
||||
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
f_stride *= shape[i];
|
||||
b_stride *= shape[ri];
|
||||
if (strides[i] > 0) {
|
||||
data_size *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in_data_size) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
return std::pair<decltype(flags), size_t>{flags, data_size};
|
||||
};
|
||||
|
||||
std::vector<int> indices(1, 0);
|
||||
indices.insert(indices.end(), indices_.begin(), indices_.end());
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
size_t offset = indices[i] * in.strides()[axis_];
|
||||
auto [new_flags, data_size] = compute_new_flags(
|
||||
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
|
||||
outputs[i].copy_shared_buffer(
|
||||
in, in.strides(), new_flags, data_size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
void Square::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -842,11 +566,6 @@ void Sqrt::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -871,38 +590,4 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
}
|
||||
|
||||
// Conditions for {row/col}_contiguous
|
||||
// - array must be contiguous (no gaps)
|
||||
// - underlying buffer size should have the same size as the array
|
||||
// - cumulative product of shapes is equal to the strides (we can ignore axes
|
||||
// with size == 1)
|
||||
// - in the forward direction (column contiguous)
|
||||
// - in the reverse direction (row contiguous)
|
||||
// - vectors are both row and col contiguous (hence if both row/col are
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
|
||||
f_stride *= out.shape(i);
|
||||
flags.row_contiguous &=
|
||||
(out_strides[ri] == b_stride || out.shape(ri) == 1);
|
||||
b_stride *= out.shape(ri);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
array out,
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
@@ -253,6 +253,81 @@ void _qmm_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
void _bs_qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
|
||||
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
|
||||
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out.data<float>() + i * M * N,
|
||||
x.data<float>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out.data<float16_t>() + i * M * N,
|
||||
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out.data<bfloat16_t>() + i * M * N,
|
||||
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
|
||||
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
|
||||
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
|
||||
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
auto& lhs_indices = inputs[4];
|
||||
auto& rhs_indices = inputs[5];
|
||||
|
||||
auto ensure_row_contiguous_last_dims = [](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {

namespace {

template <typename T>
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
  for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
    // Find the maximum
    current_in_ptr = in_ptr;
    T maximum = *current_in_ptr;
    AccT maximum = *current_in_ptr;
    for (int j = 0; j < N; j++, current_in_ptr++) {
      maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
      maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
                                            : maximum;
    }

    // Compute the normalizer and the exponentials
    T normalizer = 0;
    AccT normalizer = 0;
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
      T expv = std::exp(*current_in_ptr - maximum);
      AccT expv = std::exp(*current_in_ptr - maximum);
      normalizer += expv;
      *current_out_ptr = expv;
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr = expv;
      }
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_in_ptr = in_ptr;
    current_out_ptr = out_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++) {
      *current_out_ptr *= normalizer;
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr *= normalizer;
      } else {
        auto v = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(v * normalizer);
        current_in_ptr++;
      }
    }
  }
}
@@ -91,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float>(in, out);
      softmax<float, float>(in, out);
      break;
    case float16:
      softmax<float16_t>(in, out);
      if (precise_) {
        softmax<float16_t, float>(in, out);
      } else {
        softmax<float16_t, float16_t>(in, out);
      }
      break;
    case bfloat16:
      softmax<bfloat16_t>(in, out);
      if (precise_) {
        softmax<bfloat16_t, float>(in, out);
      } else {
        softmax<bfloat16_t, bfloat16_t>(in, out);
      }
      break;
    case complex64:
      throw std::invalid_argument(
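The precise_ branch above keeps the accumulation in float even when the stored type is half precision. A standalone sketch of that two-pass idea over one contiguous row of N elements (illustrative only, not the MLX kernel):

#include <algorithm>
#include <cmath>

template <typename T>
void softmax_row_precise(const T* in, T* out, int N) {
  // Pass 1: max and normalizer accumulated in float.
  float maximum = static_cast<float>(in[0]);
  for (int j = 1; j < N; ++j) {
    maximum = std::max(maximum, static_cast<float>(in[j]));
  }
  float normalizer = 0.f;
  for (int j = 0; j < N; ++j) {
    normalizer += std::exp(static_cast<float>(in[j]) - maximum);
  }
  // Pass 2: write the normalized exponentials back in the storage type.
  float scale = 1.f / normalizer;
  for (int j = 0; j < N; ++j) {
    out[j] = static_cast<T>(std::exp(static_cast<float>(in[j]) - maximum) * scale);
  }
}
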
@@ -3,7 +3,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {
@@ -145,12 +144,4 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}

std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0] >= 0 ? 0 : -1;
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}

} // namespace mlx::core

@@ -1,31 +1,130 @@
add_custom_command(
  OUTPUT compiled_preamble.cpp
function(make_jit_source SRC_FILE)
  # This function takes a metal header file,
  # runs the C preprocessesor on it, and makes
  # the processed contents available as a string in a C++ function
  # mlx::core::metal::${SRC_NAME}()
  #
  # To use the function, declare it in jit/includes.h and
  # include jit/includes.h.
  #
  # Additional arguments to this function are treated as dependencies
  # in the Cmake build system.
  get_filename_component(SRC_NAME ${SRC_FILE} NAME)
  add_custom_command(
    OUTPUT jit/${SRC_NAME}.cpp
    COMMAND /bin/bash
      ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
      ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
      ${CMAKE_CURRENT_BINARY_DIR}/jit
      ${CMAKE_C_COMPILER}
      ${PROJECT_SOURCE_DIR}
      ${SRC_FILE}
      "-D${MLX_METAL_VERSION}"
    DEPENDS make_compiled_preamble.sh
      kernels/compiled_preamble.h
      kernels/unary.h
      kernels/binary.h
)
      kernels/${SRC_FILE}.h
      ${ARGN}
  )
  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
  add_dependencies(mlx ${SRC_NAME})
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
  )
endfunction(make_jit_source)

add_custom_target(
  compiled_preamble
  DEPENDS compiled_preamble.cpp
make_jit_source(
  utils
  kernels/bf16.h
  kernels/complex.h
  kernels/defines.h
)
make_jit_source(
  unary_ops
  kernels/erf.h
  kernels/expm1f.h
)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
  reduce_utils
  kernels/atomic.h
  kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)

add_dependencies(mlx compiled_preamble)
if (MLX_METAL_JIT)
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
  )
  make_jit_source(arange)
  make_jit_source(copy)
  make_jit_source(unary)
  make_jit_source(binary)
  make_jit_source(binary_two)
  make_jit_source(ternary)
  make_jit_source(softmax)
  make_jit_source(scan)
  make_jit_source(sort)
  make_jit_source(
    reduce
    kernels/reduction/reduce_all.h
    kernels/reduction/reduce_col.h
    kernels/reduction/reduce_row.h
  )
  make_jit_source(
    steel/gemm/gemm
    kernels/steel/utils.h
    kernels/steel/gemm/loader.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/params.h
    kernels/steel/gemm/transforms.h
  )
  make_jit_source(steel/gemm/kernels/steel_gemm_fused)
  make_jit_source(
    steel/gemm/kernels/steel_gemm_masked
    kernels/steel/defines.h
  )
  make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
  make_jit_source(
    steel/conv/conv
    kernels/steel/utils.h
    kernels/steel/defines.h
    kernels/steel/gemm/mma.h
    kernels/steel/gemm/transforms.h
    kernels/steel/conv/params.h
    kernels/steel/conv/loader.h
    kernels/steel/conv/loaders/loader_channel_l.h
    kernels/steel/conv/loaders/loader_channel_n.h
  )
  make_jit_source(
    steel/conv/kernels/steel_conv
  )
  make_jit_source(
    steel/conv/kernels/steel_conv_general
    kernels/steel/defines.h
    kernels/steel/conv/loaders/loader_general.h
  )
else()
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
  )
endif()

target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -39,7 +138,8 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
)

if (NOT MLX_METAL_PATH)

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"

#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -139,10 +140,15 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {

MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      buffer_cache_(device_),
      block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
      max_pool_size_(block_limit_) {}
      buffer_cache_(device_) {
  auto memsize = std::get<size_t>(device_info()["memory_size"]);
  block_limit_ =
      std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
  gc_limit_ = std::min(
      static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
      block_limit_);
  max_pool_size_ = block_limit_;
}

size_t MetalAllocator::set_cache_limit(size_t limit) {
  std::swap(limit, max_pool_size_);
@@ -164,6 +170,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
    return Buffer{nullptr};
  }

  // More helpful message if maximum buffer length is exceeded
  if (size > device_->maxBufferLength()) {
    std::ostringstream msg;
    msg << "Attempting to allocate " << size << " bytes which is greater than"
        << " the maximum allowed buffer size of " << device_->maxBufferLength()
        << " bytes.";
    throw std::runtime_error(msg.str());
  }

  // Align up memory
  if (size > vm_page_size) {
    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
@@ -208,6 +223,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
  return Buffer{static_cast<void*>(buf)};
}

void MetalAllocator::clear_cache() {
  std::unique_lock lk(mutex_);
  buffer_cache_.clear();
}

void MetalAllocator::free(Buffer buffer) {
  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
  std::unique_lock lk(mutex_);
@@ -238,9 +258,15 @@ size_t get_active_memory() {
size_t get_peak_memory() {
  return allocator().get_peak_memory();
}
void reset_peak_memory() {
  allocator().reset_peak_memory();
}
size_t get_cache_memory() {
  return allocator().get_cache_memory();
}
void clear_cache() {
  return allocator().clear_cache();
}

} // namespace metal
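The new constructor sizes its limits from device_info(), which this same diff adds to the metal backend (see the device.cpp hunk below). A hedged usage sketch follows; it only builds inside the MLX source tree, and the header that exposes the declaration is an assumption taken from the include added above:

#include <cstdio>
#include <variant>

#include "mlx/backend/metal/metal.h" // assumed to declare metal::device_info()

// Query the same limits the allocator uses when sizing its buffer cache.
void print_device_limits_sketch() {
  auto info = mlx::core::metal::device_info();
  auto memsize = std::get<size_t>(info["memory_size"]);
  auto max_buffer = std::get<size_t>(info["max_buffer_length"]);
  std::printf("memory_size=%zu max_buffer_length=%zu\n", memsize, max_buffer);
}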
|
||||
|
||||
|
@@ -26,6 +26,7 @@ class BufferCache {
  size_t cache_size() {
    return pool_size_;
  }
  void clear();

 private:
  struct BufferHolder {
@@ -37,7 +38,6 @@ class BufferCache {
    MTL::Buffer* buf;
  };

  void clear();
  void add_at_head(BufferHolder* to_add);
  void remove_from_list(BufferHolder* to_remove);

@@ -62,11 +62,16 @@ class MetalAllocator : public allocator::Allocator {
  size_t get_peak_memory() {
    return peak_memory_;
  };
  void reset_peak_memory() {
    std::unique_lock lk(mutex_);
    peak_memory_ = 0;
  };
  size_t get_cache_memory() {
    return buffer_cache_.cache_size();
  };
  size_t set_cache_limit(size_t limit);
  size_t set_memory_limit(size_t limit, bool relaxed);
  void clear_cache();

 private:
  MTL::Device* device_;
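Together with the allocator.cpp hunk above, these accessors back a small cache/peak-memory API in the metal namespace. A hedged usage sketch (the functions are the ones added in this diff; the header that exposes them is an assumption):

#include <cstdio>

#include "mlx/backend/metal/metal.h" // assumed declaration site

void report_metal_memory_sketch() {
  namespace metal = mlx::core::metal;
  metal::reset_peak_memory(); // start a fresh high-water mark
  // ... run some GPU work here ...
  std::printf("peak bytes:   %zu\n", metal::get_peak_memory());
  std::printf("cached bytes: %zu\n", metal::get_cache_memory());
  metal::clear_cache(); // release cached buffers back to the system
}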
|
||||
|
322
mlx/backend/metal/binary.cpp
Normal file
@@ -0,0 +1,322 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
// otherwise it goes to the second output
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
|
||||
compute_encoder.set_input_array(
|
||||
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
|
||||
compute_encoder.set_output_array(outputs[0], 2);
|
||||
compute_encoder.set_output_array(outputs[1], 3);
|
||||
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
compute_encoder.set_input_array(donate_a ? out : a, 0);
|
||||
compute_encoder.set_input_array(donate_b ? out : b, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads =
|
||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "add");
|
||||
}
|
||||
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "arctan2");
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op(inputs, out, "bitwise_and");
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op(inputs, out, "bitwise_or");
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op(inputs, out, "bitwise_xor");
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op(inputs, out, "left_shift");
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op(inputs, out, "right_shift");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
||||
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "ge");
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "geq");
|
||||
}
|
||||
|
||||
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "le");
|
||||
}
|
||||
|
||||
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "land");
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lor");
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lae");
|
||||
}
|
||||
|
||||
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "max");
|
||||
}
|
||||
|
||||
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
|
||||
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "neq");
|
||||
}
|
||||
|
||||
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "pow");
|
||||
}
|
||||
|
||||
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "sub");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -4,8 +4,8 @@

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
@@ -190,7 +190,8 @@ void Compiled::eval_gpu(
  // If not we have to build it ourselves
  if (lib == nullptr) {
    std::ostringstream kernel;
    kernel << metal::get_kernel_preamble() << std::endl;
    kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
           << metal::ternary_ops();
    build_kernel(
        kernel,
        kernel_lib_ + "_contiguous",
@@ -229,14 +230,7 @@ void Compiled::eval_gpu(

  // Figure out which kernel we are using
  auto& output_shape = outputs[0].shape();
  bool contiguous = true;
  for (auto& x : inputs) {
    if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
        !is_scalar(x)) {
      contiguous = false;
      break;
    }
  }
  bool contiguous = compiled_check_contiguity(inputs, output_shape);

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
@@ -296,7 +290,7 @@ void Compiled::eval_gpu(
    }
  }
  auto kernel = d.get_kernel(kernel_name, lib);
  auto compute_encoder = d.get_command_encoder(s.index);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  // Put the inputs in
@@ -307,7 +301,7 @@ void Compiled::eval_gpu(
      continue;
    }
    auto& x = inputs[i];
    set_array_buffer(compute_encoder, x, cnt++);
    compute_encoder.set_input_array(x, cnt++);
    if (!contiguous && !is_scalar(x)) {
      compute_encoder->setBytes(
          strides[stride_idx].data(),
@@ -317,32 +311,12 @@ void Compiled::eval_gpu(
    }
  }

  // Allocate space for the outputs possibly with input donation
  {
    int o = 0;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Row contiguous
      // - Donatable
      // - Correct size
      // - Not a constant
      if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        outputs[o].move_shared_buffer(
            in, outputs[o].strides(), in.flags(), in.data_size());
        o++;
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
    }
  }
  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous, true);

  // Put the outputs in
  for (auto& x : outputs) {
    set_array_buffer(compute_encoder, x, cnt++);
    compute_encoder.set_output_array(x, cnt++);
  }

  // Put the output shape and strides in
@@ -363,7 +337,7 @@ void Compiled::eval_gpu(
    MTL::Size grid_dims(nthreads, 1, 1);
    MTL::Size group_dims(
        std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -374,7 +348,7 @@ void Compiled::eval_gpu(
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}
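The @@ -229 hunk above folds the inline contiguity loop into compiled_check_contiguity(). A hedged sketch of that check follows; the body is taken from the removed lines, while the function name, signature, and the exact header that provides the real helper (and is_scalar) are assumptions:

#include <vector>

#include "mlx/array.h" // array and flags(); is_scalar() comes from MLX's compiled helpers

namespace mlx::core {

// Every non-scalar input must be row-contiguous and share the output shape
// for the contiguous (fast-path) kernel to be eligible.
bool check_contiguity_sketch(
    const std::vector<array>& inputs,
    const std::vector<int>& output_shape) {
  for (auto& x : inputs) {
    if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
        !is_scalar(x)) {
      return false;
    }
  }
  return true;
}

} // namespace mlx::core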
|
||||
|
||||
|
@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* get_kernel_preamble();

}
|
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
@@ -41,12 +42,12 @@ void explicit_gemm_conv_ND_gpu(
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, in_unfolded, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
@@ -59,7 +60,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
@@ -89,6 +90,90 @@ void explicit_gemm_conv_ND_gpu(
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
void explicit_gemm_conv_group_ND_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<N>& conv_params) {
|
||||
const int groups = conv_params.groups;
|
||||
const int C_per_group = conv_params.C / conv_params.groups;
|
||||
const int O_per_group = conv_params.O / conv_params.groups;
|
||||
// Get gemm shapes
|
||||
const int implicit_M = out.size() / conv_params.O;
|
||||
const int implicit_K = wt.size() / conv_params.O;
|
||||
const int implicit_N = O_per_group;
|
||||
|
||||
int kernel_size = 1;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
kernel_size *= conv_params.wS[i];
|
||||
}
|
||||
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
|
||||
<< N;
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Transpose kernel weights so that we can slice them by contiguous chunks
|
||||
// of channel groups.
|
||||
array wt_view(
|
||||
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt,
|
||||
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
|
||||
wt.flags(),
|
||||
wt.size());
|
||||
|
||||
// Materialize
|
||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
|
||||
return steel_matmul_conv_groups(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt_transpose,
|
||||
/*c = */ out,
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*a_cols = */ implicit_K * groups,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*out_cols = */ implicit_N * groups,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/* groups = */ groups,
|
||||
/*copies = */ copies);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -99,6 +184,7 @@ void conv_1D_gpu(
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
int groups,
|
||||
bool flip) {
|
||||
// Make conv params
|
||||
MLXConvParams<1> conv_params{
|
||||
@@ -118,11 +204,15 @@ void conv_1D_gpu(
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||
/* const int groups = */ 1,
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip};
|
||||
|
||||
// Direct to explicit gemm conv
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
if (groups > 1) {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
void slow_conv_2D_gpu(
|
||||
@@ -140,7 +230,7 @@ void slow_conv_2D_gpu(
|
||||
<< "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -153,12 +243,12 @@ void slow_conv_2D_gpu(
|
||||
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_gpu(
|
||||
@@ -168,15 +258,19 @@ void implicit_gemm_conv_2D_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
const int groups = conv_params.groups;
|
||||
const int C_per_group = conv_params.C / conv_params.groups;
|
||||
const int O_per_group = conv_params.O / conv_params.groups;
|
||||
|
||||
// Deduce implicit gemm size
|
||||
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
int implicit_N = conv_params.O;
|
||||
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
||||
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
const int implicit_N = O_per_group;
|
||||
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
|
||||
|
||||
// Determine block and warp tiles
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
|
||||
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
|
||||
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
|
||||
int bk = 16;
|
||||
|
||||
@@ -192,15 +286,15 @@ void implicit_gemm_conv_2D_gpu(
|
||||
|
||||
// Fix small channel specialization
|
||||
int n_channel_specialization = 0;
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int channel_k_iters = ((C_per_group + bk - 1) / bk);
|
||||
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
|
||||
|
||||
if (conv_params.C <= 2) {
|
||||
if (C_per_group <= 2) {
|
||||
gemm_k_iters = (implicit_K + bk - 1) / bk;
|
||||
n_channel_specialization = conv_params.C;
|
||||
} else if (conv_params.C <= 4) {
|
||||
n_channel_specialization = C_per_group;
|
||||
} else if (C_per_group <= 4) {
|
||||
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
|
||||
n_channel_specialization = conv_params.C;
|
||||
n_channel_specialization = C_per_group;
|
||||
}
|
||||
|
||||
bool small_filter = (!n_channel_specialization) &&
|
||||
@@ -241,8 +335,18 @@ void implicit_gemm_conv_2D_gpu(
|
||||
<< "_filter_" << (small_filter ? 's' : 'l');
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_conv_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
out,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
n_channel_specialization,
|
||||
small_filter);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@@ -251,19 +355,19 @@ void implicit_gemm_conv_2D_gpu(
|
||||
size_t grid_dim_x = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void implicit_gemm_conv_2D_general_gpu(
|
||||
@@ -394,8 +498,9 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@@ -408,9 +513,9 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
@@ -423,7 +528,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void winograd_conv_2D_gpu(
|
||||
@@ -511,12 +616,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, wt, 0);
|
||||
set_array_buffer(compute_encoder, filt_wg, 1);
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
compute_encoder.set_output_array(filt_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
@@ -524,7 +629,7 @@ void winograd_conv_2D_gpu(
|
||||
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do input transform
|
||||
@@ -539,12 +644,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in_padded, 0);
|
||||
set_array_buffer(compute_encoder, inp_wg, 1);
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
@@ -552,7 +657,7 @@ void winograd_conv_2D_gpu(
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Do batched gemm
|
||||
@@ -587,12 +692,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, out_wg, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
@@ -600,7 +705,7 @@ void winograd_conv_2D_gpu(
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,6 +719,7 @@ void conv_2D_gpu(
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
const int groups,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
@@ -629,12 +735,12 @@ void conv_2D_gpu(
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
|
||||
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||
/* const int groups = */ 1,
|
||||
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
|
||||
@@ -646,6 +752,18 @@ void conv_2D_gpu(
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
||||
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
||||
|
||||
if (groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
@@ -670,6 +788,56 @@ void conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void conv_3D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
MLXConvParams<3> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(4),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
|
||||
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
|
||||
/* const int kdil[NDIM] = */
|
||||
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
|
||||
/* const int idil[NDIM] = */
|
||||
{in_dilation[0], in_dilation[1], in_dilation[2]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0],
|
||||
in.strides()[1],
|
||||
in.strides()[2],
|
||||
in.strides()[3],
|
||||
in.strides()[4]},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0],
|
||||
wt.strides()[1],
|
||||
wt.strides()[2],
|
||||
wt.strides()[3],
|
||||
wt.strides()[4]},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0],
|
||||
out.strides()[1],
|
||||
out.strides()[2],
|
||||
out.strides()[3],
|
||||
out.strides()[4]},
|
||||
/* const int groups = */ 1,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -694,8 +862,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
wt = arr_copy;
|
||||
}
|
||||
|
||||
// 3D conv
|
||||
if (out.ndim() == 5) {
|
||||
conv_3D_gpu(
|
||||
s,
|
||||
d,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
// 2D conv
|
||||
if (out.ndim() == 4) {
|
||||
else if (out.ndim() == 4) {
|
||||
conv_2D_gpu(
|
||||
s,
|
||||
d,
|
||||
@@ -706,6 +889,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
@@ -721,6 +905,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_);
|
||||
}
|
||||
// Throw error
|
||||
|
@@ -4,12 +4,14 @@

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;

void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  if (ctype == CopyType::Vector) {
    // If the input is donateable, we are doing a vector copy and the types
@@ -62,36 +64,43 @@ void copy_gpu_inplace(
  auto& strides_out_ = strides[1];

  auto& d = metal::device(s.device);
  std::ostringstream kname;
  switch (ctype) {
    case CopyType::Scalar:
      kname << "scopy";
      break;
    case CopyType::Vector:
      kname << "vcopy";
      break;
    case CopyType::General:
      kname << "gcopy";
      break;
    case CopyType::GeneralGeneral:
      kname << "ggcopy";
      break;
  std::string kernel_name;
  {
    std::ostringstream kname;
    switch (ctype) {
      case CopyType::Scalar:
        kname << "s";
        break;
      case CopyType::Vector:
        kname << "v";
        break;
      case CopyType::General:
        kname << "g";
        break;
      case CopyType::GeneralGeneral:
        kname << "gg";
        break;
    }
    if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
        shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
      kname << shape.size();
    }
    kname << "_copy";
    kname << type_to_name(in) << type_to_name(out);
    kernel_name = kname.str();
  }
  kname << type_to_name(in) << type_to_name(out);
  if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
      shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  auto kernel = get_copy_kernel(d, kernel_name, in, out);

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_in = in.data_shared_ptr() == nullptr;

  inp_offset *= size_of(in.dtype());
  out_offset *= size_of(out.dtype());

  set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
  set_array_buffer(compute_encoder, out, out_offset, 1);
  compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
  compute_encoder.set_output_array(out, 1, out_offset);

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    int ndim = shape.size();
@@ -106,7 +115,7 @@ void copy_gpu_inplace(
      set_vector_bytes(compute_encoder, strides_out, ndim, 4);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 5);
    }

@@ -126,7 +135,7 @@ void copy_gpu_inplace(

    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +144,7 @@ void copy_gpu_inplace(
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}
|
||||
|
||||
|
@@ -1,16 +1,19 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
#define NS_PRIVATE_IMPLEMENTATION
|
||||
#define CA_PRIVATE_IMPLEMENTATION
|
||||
#define MTL_PRIVATE_IMPLEMENTATION
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
@@ -22,9 +25,18 @@ namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
constexpr auto get_metal_version() {
|
||||
#if defined METAL_3_1
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
#endif
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||
@@ -34,7 +46,6 @@ auto load_device() {
|
||||
}
|
||||
return device;
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
MTL::Device* device,
|
||||
const char* path) {
|
||||
@@ -113,6 +124,33 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
num_dispatches++;
|
||||
enc->dispatchThreadgroups(grid_dims, group_dims);
|
||||
maybe_split();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
num_dispatches++;
|
||||
enc->dispatchThreads(grid_dims, group_dims);
|
||||
maybe_split();
|
||||
}
|
||||
|
||||
void CommandEncoder::maybe_split() {
|
||||
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
num_dispatches = 0;
|
||||
outputs.clear();
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
}
|
||||
}
|
||||
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
@@ -127,9 +165,6 @@ Device::~Device() {
|
||||
for (auto& b : buffer_map_) {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& e : encoder_map_) {
|
||||
e.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
@@ -166,27 +201,26 @@ void Device::increment_command_buffer_ops(int index) {
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
|
||||
}
|
||||
if (bit == buffer_map_.end()) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::new_command_buffer(int index) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
bit = buffer_map_.insert({index, {0, cb}}).first;
|
||||
}
|
||||
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
return buffer_map_.insert({index, {0, cb}}).first->second.second;
|
||||
return bit->second.second;
|
||||
}
|
||||
|
||||
void Device::commit_command_buffer(int index) {
|
||||
@@ -197,24 +231,17 @@ void Device::commit_command_buffer(int index) {
|
||||
}
|
||||
|
||||
void Device::end_encoding(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit != encoder_map_.end()) {
|
||||
eit->second->endEncoding();
|
||||
eit->second->release();
|
||||
encoder_map_.erase(eit);
|
||||
}
|
||||
encoder_map_.erase(index);
|
||||
}
|
||||
|
||||
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
auto cb = get_command_buffer(index);
|
||||
auto compute_encoder = cb->computeCommandEncoder();
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
compute_encoder->retain();
|
||||
eit = encoder_map_.insert({index, compute_encoder}).first;
|
||||
eit =
|
||||
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
|
||||
}
|
||||
return eit->second;
|
||||
return *(eit->second);
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
@@ -256,13 +283,16 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
|
||||
auto options = MTL::CompileOptions::alloc()->init();
|
||||
options->setFastMathEnabled(false);
|
||||
options->setLanguageVersion(get_metal_version());
|
||||
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
|
||||
options->release();
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build metal library from source"
|
||||
<< "\n";
|
||||
msg << "[metal::Device] Unable to build metal library from source" << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
@@ -281,8 +311,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build stitched metal library"
|
||||
<< "\n";
|
||||
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
@@ -340,7 +369,6 @@ MTL::Function* Device::get_function_(
|
||||
}
|
||||
|
||||
mtl_func_consts->release();
|
||||
desc->release();
|
||||
|
||||
return mtl_function;
|
||||
}
|
||||
@@ -509,11 +537,13 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
||||
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({kname, kernel});
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
@@ -540,11 +570,12 @@ Device& device(mlx::core::Device) {
|
||||
return metal_device;
|
||||
}
|
||||
|
||||
std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
auto dtor = [](void* ptr) {
|
||||
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
||||
};
|
||||
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
return std::unique_ptr<void, std::function<void(void*)>>(
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
@@ -553,4 +584,23 @@ void new_stream(Stream stream) {
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::variant<std::string, size_t>>
|
||||
device_info() {
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
int mib[] = {CTL_HW, HW_MEMSIZE};
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
|
||||
sysctl(mib, 2, &memsize, &length, NULL, 0);
|
||||
|
||||
return {
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize}};
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -7,10 +7,12 @@
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include <dlfcn.h>
#include <filesystem>

#include "mlx/array.h"
#include "mlx/device.h"

namespace fs = std::filesystem;
@@ -34,6 +36,84 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
using MTLFCList =
    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

struct CommandEncoder {
  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
    enc->retain();
  };
  CommandEncoder(const CommandEncoder&) = delete;
  CommandEncoder& operator=(const CommandEncoder&) = delete;

  struct ConcurrentContext {
    ConcurrentContext(CommandEncoder& enc) : enc(enc) {
      enc.concurrent = true;
    }
    ~ConcurrentContext() {
      enc.concurrent = false;
      enc.outputs.insert(
          enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
      enc.concurrent_outputs.clear();
    }

   private:
    CommandEncoder& enc;
  };

  MTL::ComputeCommandEncoder* operator->() {
    return enc;
  }

  void set_input_array(const array& a, int idx, int64_t offset = 0) {
    auto r_buf =
        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
    if (auto it = outputs.find(r_buf); it != outputs.end()) {
      // Insert a barrier
      enc->memoryBarrier(&r_buf, 1);

      // Remove the output
      outputs.erase(it);
    }
    auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
    auto base_offset = a.data<char>() -
        static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
    base_offset += offset;
    enc->setBuffer(a_buf, base_offset, idx);
  }

  void set_output_array(array& a, int idx, int64_t offset = 0) {
    // Add barriers before adding the output to the output set
    set_input_array(a, idx, offset);
    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
    if (concurrent) {
      concurrent_outputs.insert(buf);
    } else {
      outputs.insert(buf);
    }
  }

  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);

  ConcurrentContext start_concurrent() {
    return ConcurrentContext(*this);
  }

  ~CommandEncoder() {
    enc->endEncoding();
    enc->release();
  }

 private:
  void maybe_split();

  int num_dispatches{0};
  MTL::CommandBuffer* cbuf;
  MTL::ComputeCommandEncoder* enc;
  bool concurrent{false};
  std::unordered_set<MTL::Resource*> outputs;
  std::unordered_set<MTL::Resource*> concurrent_outputs;
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device();
|
||||
@@ -46,12 +126,11 @@ class Device {
|
||||
};
|
||||
|
||||
void new_queue(int index);
|
||||
MTL::CommandBuffer* new_command_buffer(int index);
|
||||
MTL::CommandBuffer* get_command_buffer(int index);
|
||||
int get_command_buffer_ops(int index);
|
||||
void increment_command_buffer_ops(int index);
|
||||
void commit_command_buffer(int index);
|
||||
MTL::ComputeCommandEncoder* get_command_encoder(int index);
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
@@ -132,7 +211,7 @@ class Device {
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
|
||||
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
std::mutex mtx_;
|
||||
|
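A hedged usage sketch (mine, not from the diff) of the CommandEncoder wrapper declared above; it assumes a pipeline state has already been built elsewhere.

// Illustrative only -- shows how a primitive might drive the new wrapper.
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

namespace mlx::core {

void example_dispatch(
    metal::Device& d,
    int stream_index,
    const array& in,
    array& out,
    MTL::ComputePipelineState* kernel /* assumed to be built already */) {
  auto& enc = d.get_command_encoder(stream_index);
  enc->setComputePipelineState(kernel); // operator-> exposes the raw encoder
  enc.set_input_array(in, 0);  // inserts a barrier if `in` was a prior output
  enc.set_output_array(out, 1); // records `out` so later readers get a barrier
  enc.dispatchThreads(MTL::Size(out.size(), 1, 1), MTL::Size(256, 1, 1));
}

} // namespace mlx::core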
30  mlx/backend/metal/event.cpp  Normal file
@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.

#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"

namespace mlx::core {

Event::Event(const Stream& stream) : stream_(stream) {
  auto dtor = [](void* ptr) {
    auto p = metal::new_scoped_memory_pool();
    static_cast<MTL::SharedEvent*>(ptr)->release();
  };
  auto p = metal::new_scoped_memory_pool();
  event_ = std::shared_ptr<void>(
      metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
}

void Event::wait() {
  if (!static_cast<MTL::SharedEvent*>(raw_event().get())
           ->waitUntilSignaledValue(value(), -1)) {
    throw std::runtime_error("[Event::wait] Timed out");
  }
}

void Event::signal() {
  static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}

} // namespace mlx::core
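A minimal sketch (not part of the diff) of how this Event might be exercised; whether "mlx/event.h" also brings in Stream is an assumption.

// Illustrative only -- a producer signals the shared event, a consumer waits.
#include "mlx/event.h"

void example(const mlx::core::Stream& s) {
  mlx::core::Event e(s);
  e.signal(); // sets the MTL::SharedEvent's signaled value to e.value()
  e.wait();   // returns once waitUntilSignaledValue observes that value
}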
@@ -1,12 +1,106 @@
// Copyright © 2023 Apple Inc.

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  auto& in = inputs[0];
  throw std::runtime_error("[FFT] NYI for Metal backend.");

  if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
      in.dtype() != complex64 || out.dtype() != complex64) {
    // Could also fallback to CPU implementation here.
    throw std::runtime_error(
        "GPU FFT is only implemented for 1D, forward, complex FFTs.");
  }

  size_t n = in.shape(axes_[0]);

  if (!is_power_of_2(n) || n > 2048 || n < 4) {
    throw std::runtime_error(
        "GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
  }

  // Make sure that the array is contiguous and has stride 1 in the FFT dim
  std::vector<array> copies;
  auto check_input = [this, &copies, &s](const array& x) {
    // TODO: Pass the strides to the kernel so
    // we can avoid the copy when x is not contiguous.
    bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
        x.flags().col_contiguous;
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      std::vector<size_t> strides;
      size_t cur_stride = x.shape(axes_[0]);
      for (int axis = 0; axis < x.ndim(); axis++) {
        if (axis == axes_[0]) {
          strides.push_back(1);
        } else {
          strides.push_back(cur_stride);
          cur_stride *= x.shape(axis);
        }
      }

      auto flags = x.flags();
      size_t f_stride = 1;
      size_t b_stride = 1;
      flags.col_contiguous = true;
      flags.row_contiguous = true;
      for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
        flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
        f_stride *= x.shape(i);
        flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
        b_stride *= x.shape(ri);
      }
      // This is probably over-conservative
      flags.contiguous = false;

      x_copy.set_data(
          allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
      copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
  const array& in_contiguous = check_input(inputs[0]);

  // TODO: allow donation here
  out.set_data(
      allocator::malloc_or_wait(out.nbytes()),
      in_contiguous.data_size(),
      in_contiguous.strides(),
      in_contiguous.flags());

  // We use n / 4 threads by default since radix-4
  // is the largest single threaded radix butterfly
  // we currently implement.
  size_t m = n / 4;
  size_t batch = in.size() / in.shape(axes_[0]);

  auto& compute_encoder = d.get_command_encoder(s.index);
  {
    std::ostringstream kname;
    kname << "fft_" << n;
    auto kernel = d.get_kernel(kname.str());

    bool donated = in.data_shared_ptr() == nullptr;
    compute_encoder->setComputePipelineState(kernel);
    compute_encoder.set_input_array(in_contiguous, 0);
    compute_encoder.set_output_array(out, 1);

    auto group_dims = MTL::Size(1, m, 1);
    auto grid_dims = MTL::Size(batch, m, 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core
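As a quick sanity check of the launch geometry above, a worked example (illustrative numbers, not from the diff):

// For a (8, 256) complex64 input transformed along the last axis:
//   n = 256 (a power of 2 within [4, 2048])
//   m = n / 4 = 64 threads per FFT (radix-4 butterflies)
//   batch = in.size() / n = 8 * 256 / 256 = 8
// so group_dims = MTL::Size(1, 64, 1) and grid_dims = MTL::Size(8, 64, 1).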
@@ -1,24 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include <fmt/format.h>

#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 20;

constexpr int METAL_MAX_INDEX_ARRAYS = 10;

} // namespace
std::pair<std::string, std::string> make_index_args(
    const std::string& idx_type,
    int nidx) {
  std::ostringstream idx_args;
  std::ostringstream idx_arr;
  for (int i = 0; i < nidx; ++i) {
    idx_args << fmt::format(
        "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
    idx_arr << fmt::format("idx{0}", i);
    if (i < nidx - 1) {
      idx_args << "\n";
      idx_arr << ",";
    }
  }
  return {idx_args.str(), idx_arr.str()};
}

void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  size_t ndim = src.ndim();

  std::ostringstream kname;
  std::string lib_name;
  std::string kernel_name;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
  if (idx_ndim <= 1) {
    kname << "_" << idx_ndim;
  {
    std::ostringstream kname;
    kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
          << "_" << idx_ndim;
    lib_name = kname.str();
    kernel_name = lib_name;
  }

  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gather();
    std::string out_type_str = get_type_string(out.dtype());
    std::string idx_type_str =
        nidx ? get_type_string(inputs[1].dtype()) : "bool";
    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);

    // Index dimension specializations
    kernel_source << fmt::format(
        gather_kernels,
        type_to_name(out) + idx_type_name,
        out_type_str,
        idx_type_str,
        nidx,
        idx_args,
        idx_arr,
        idx_ndim);
    lib = d.get_library(lib_name, kernel_source.str());
  }

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kernel_name, lib);
  compute_encoder->setComputePipelineState(kernel);

  size_t slice_size = 1;
@@ -81,8 +118,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  // Set all the buffers
  set_array_buffer(compute_encoder, src, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder.set_input_array(src, 0);
  compute_encoder.set_output_array(out, 1);

  // Set source info
  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
@@ -102,12 +139,12 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);

  // Set index buffers
  for (int i = 1; i < nidx + 1; ++i) {
    set_array_buffer(compute_encoder, inputs[i], 20 + i);
  for (int i = 0; i < nidx; ++i) {
    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
  }

  // Launch grid
  compute_encoder->dispatchThreads(grid_dims, group_dims);
  compute_encoder.dispatchThreads(grid_dims, group_dims);
}

void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Get kernel name
  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  bool index_nd1_specialization = (idx_ndim == 1);

@@ -159,32 +192,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    index_nd1_specialization &= inputs[i].flags().row_contiguous;
  }

  if (index_nd1_specialization) {
    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
  } else {
    kname << "scatter" << type_to_name(out) << idx_type_name;
  }
  std::string lib_name;
  std::string kernel_name;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  std::string op_name;
  switch (reduce_type_) {
    case Scatter::None:
      kname << "_none";
      op_name = "none";
      break;
    case Scatter::Sum:
      kname << "_sum";
      op_name = "sum";
      break;
    case Scatter::Prod:
      kname << "_prod";
      op_name = "prod";
      break;
    case Scatter::Max:
      kname << "_max";
      op_name = "max";
      break;
    case Scatter::Min:
      kname << "_min";
      op_name = "min";
      break;
  }
  kname << "_" << nidx;

  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  {
    std::ostringstream kname;
    if (index_nd1_specialization) {
      kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
    } else {
      kname << "scatter" << type_to_name(out) << idx_type_name;
    }
    kname << "_" << op_name << "_" << nidx;
    lib_name = kname.str();
    kernel_name = kname.str();
  }

  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::reduce_utils()
                  << metal::scatter();

    std::string out_type_str = get_type_string(out.dtype());
    std::string idx_type_str =
        nidx ? get_type_string(inputs[1].dtype()) : "bool";
    std::string op_type;
    switch (reduce_type_) {
      case Scatter::None:
        op_type = "None";
        break;
      case Scatter::Sum:
        op_type = "Sum<{0}>";
        break;
      case Scatter::Prod:
        op_type = "Prod<{0}>";
        break;
      case Scatter::Max:
        op_type = "Max<{0}>";
        break;
      case Scatter::Min:
        op_type = "Min<{0}>";
        break;
    }
    if (reduce_type_ != Scatter::None) {
      op_type = fmt::format(op_type, out_type_str);
    }
    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);

    kernel_source << fmt::format(
        scatter_kernels,
        type_to_name(out) + idx_type_name + "_" + op_name,
        out_type_str,
        idx_type_str,
        op_type,
        nidx,
        idx_args,
        idx_arr);
    lib = d.get_library(lib_name, kernel_source.str());
  }

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kernel_name, lib);

  auto& upd = inputs.back();
  size_t nthreads = upd.size();
@@ -192,8 +279,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setComputePipelineState(kernel);

    // Set all the buffers
    set_array_buffer(compute_encoder, upd, 1);
    set_array_buffer(compute_encoder, out, 2);
    compute_encoder.set_input_array(upd, 1);
    compute_encoder.set_output_array(out, 2);

    // Set update info
    uint upd_ndim = upd.ndim();
@@ -209,14 +296,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);

    // Set index buffers
    for (int i = 1; i < nidx + 1; ++i) {
      set_array_buffer(compute_encoder, inputs[i], 20 + i);
    for (int i = 0; i < nidx; ++i) {
      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
    }

    // Launch grid
    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);

  } else {
    // Collect all idx shapes and strides into one place
@@ -279,14 +366,14 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

    // Set index buffers
    for (int i = 1; i < nidx + 1; ++i) {
      set_array_buffer(compute_encoder, inputs[i], 20 + i);
    for (int i = 0; i < nidx; ++i) {
      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
    }

    // Launch grid
    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}
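For concreteness, a sketch (assumed values, not from the diff) of the strings make_index_args produces and splices into the JIT'd gather/scatter source:

// With idx_type_str = "int32_t" (assumed) and nidx = 2, make_index_args gives:
//   idx_args: "const device int32_t *idx0 [[buffer(20)]],\n"
//             "const device int32_t *idx1 [[buffer(21)]],"
//   idx_arr:  "idx0,idx1"
// These become the extra kernel parameters and the initializer list for the
// Indices struct in gather_kernels / scatter_kernels (see jit/indexing.h below).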
9  mlx/backend/metal/jit/arange.h  Normal file
@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
    constant const {1}& start,
    constant const {1}& step,
    device {1}* out,
    uint index [[thread_position_in_grid]]);
)";
87  mlx/backend/metal/jit/binary.h  Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
98  mlx/backend/metal/jit/binary_two.h  Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_two_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
100  mlx/backend/metal/jit/copy.h  Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view copy_kernels = R"(
|
||||
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4_{0}")]] [[kernel]] void
|
||||
copy_g_nd<{1}, {2}, 4>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg4_{0}")]] [[kernel]] void
|
||||
copy_gg_nd<{1}, {2}, 4>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
template [[host_name("g5_{0}")]] [[kernel]] void
|
||||
copy_g_nd<{1}, {2}, 5>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg5_{0}")]] [[kernel]] void
|
||||
copy_gg_nd<{1}, {2}, 5>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg1_{0}")]] [[kernel]] void
|
||||
copy_gg_nd1<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("gg2_{0}")]] [[kernel]] void
|
||||
copy_gg_nd2<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]);
|
||||
template [[host_name("gg3_{0}")]] [[kernel]] void
|
||||
copy_gg_nd3<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
)";
|
34  mlx/backend/metal/jit/includes.h  Normal file
@@ -0,0 +1,34 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();

const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();
const char* scan();
const char* softmax();
const char* sort();
const char* reduce();

const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();

} // namespace mlx::core::metal
81  mlx/backend/metal/jit/indexing.h  Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
{4}
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
src_strides,
|
||||
src_ndim,
|
||||
slice_sizes,
|
||||
axes,
|
||||
idxs,
|
||||
index,
|
||||
grid_dim);
|
||||
}}
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
upd_strides,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
)";
|
168  mlx/backend/metal/jit/reduce.h  Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view reduce_init_kernels = R"(
|
||||
[[kernel]] void {0}(
|
||||
device {1}* out [[buffer(0)]],
|
||||
uint tid [[thread_position_in_grid]]) {{
|
||||
out[tid] = {2}<{1}>::init;
|
||||
}}
|
||||
)";
|
||||
|
||||
constexpr std::string_view reduce_kernels = R"(
|
||||
template [[host_name("all_{0}")]] [[kernel]] void
|
||||
all_reduce<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device mlx_atomic<{2}>* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("colGeneral_{0}")]] [[kernel]] void
|
||||
col_reduce_general<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device mlx_atomic<{2}>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup {2}* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]);
|
||||
template [[host_name("colSmall_{0}")]] [[kernel]] void
|
||||
col_reduce_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
const constant size_t& non_col_reductions [[buffer(8)]],
|
||||
const constant int* non_col_shapes [[buffer(9)]],
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
|
||||
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
|
||||
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
|
||||
row_reduce_general<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device mlx_atomic<{2}>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view reduce_non_atomic_kernels = R"(
|
||||
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
|
||||
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]);
|
||||
|
||||
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
|
||||
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup {2}* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]]);
|
||||
template [[host_name("colSmall_{0}")]] [[kernel]] void
|
||||
col_reduce_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
const constant size_t& non_col_reductions [[buffer(8)]],
|
||||
const constant int* non_col_shapes [[buffer(9)]],
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
|
||||
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
|
||||
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
26  mlx/backend/metal/jit/scan.h  Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view scan_kernels = R"(
|
||||
template [[host_name("contig_{0}")]] [[kernel]] void
|
||||
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
template [[host_name("strided_{0}")]] [[kernel]] void
|
||||
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
const constant size_t& stride [[buffer(3)]],
|
||||
uint2 gid [[thread_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]]);
|
||||
)";
|
23  mlx/backend/metal/jit/softmax.h  Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view softmax_kernels = R"(
|
||||
template [[host_name("block_{0}")]] [[kernel]] void
|
||||
softmax_single_row<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("looped_{0}")]] [[kernel]] void
|
||||
softmax_looped<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
81  mlx/backend/metal/jit/sort.h  Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view block_sort_kernels = R"(
|
||||
template [[host_name("carg_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("ncarg_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("c_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("nc_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view multiblock_sort_kernels = R"(
|
||||
template [[host_name("sort_{0}")]] [[kernel]] void
|
||||
mb_block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {1}* out_vals [[buffer(1)]],
|
||||
device {2}* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("partition_{0}")]] [[kernel]] void
|
||||
mb_block_partition<{1}, {2}, true, {3}, {4}>(
|
||||
device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals [[buffer(1)]],
|
||||
const device {2}* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]);
|
||||
template [[host_name("merge_{0}")]] [[kernel]] void
|
||||
mb_block_merge<{1}, {2}, true, {3}, {4}>(
|
||||
const device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals_in [[buffer(1)]],
|
||||
const device {2}* dev_idxs_in [[buffer(2)]],
|
||||
device {1}* dev_vals_out [[buffer(3)]],
|
||||
device {2}* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
32  mlx/backend/metal/jit/steel_conv.h  Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view steel_conv_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_conv_general_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
106  mlx/backend/metal/jit/steel_gemm.h  Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view steel_gemm_fused_kernels = R"(
|
||||
template [[host_name("{name}")]]
|
||||
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
|
||||
const device {itype} *A [[buffer(0)]],
|
||||
const device {itype} *B [[buffer(1)]],
|
||||
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
|
||||
device {itype} *D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_masked_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
block_masked_gemm<
|
||||
{itype},
|
||||
{outmasktype},
|
||||
{opmasktype},
|
||||
{bm},
|
||||
{bn},
|
||||
{bk},
|
||||
{wm},
|
||||
{wn},
|
||||
{trans_a},
|
||||
{trans_b},
|
||||
{mn_aligned},
|
||||
{k_aligned}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const device {outmasktype}* out_mask [[buffer(10)]],
|
||||
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
||||
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
||||
const constant int* mask_strides [[buffer(13)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_splitk_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemm_splitk<
|
||||
{itype},
|
||||
{otype},
|
||||
{bm},
|
||||
{bn},
|
||||
{bk},
|
||||
{wm},
|
||||
{wn},
|
||||
{trans_a},
|
||||
{trans_b},
|
||||
{mn_aligned},
|
||||
{k_aligned}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {otype}* C [[buffer(2)]],
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemm_splitk_accum<{atype}, {otype}>(
|
||||
const device {atype}* C_split [[buffer(0)]],
|
||||
device {otype}* D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemm_splitk_accum_axpby<{atype}, {otype}>(
|
||||
const device {atype}* C_split [[buffer(0)]],
|
||||
device {otype}* D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
const device {otype}* C [[buffer(5)]],
|
||||
const constant int& ldc [[buffer(6)]],
|
||||
const constant int& fdc [[buffer(7)]],
|
||||
const constant float& alpha [[buffer(8)]],
|
||||
const constant float& beta [[buffer(9)]],
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
)";
|
80  mlx/backend/metal/jit/ternary.h  Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright © 2024 Apple Inc.

constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    uint index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    constant const size_t c_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    constant const size_t c_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
16  mlx/backend/metal/jit/unary.h  Normal file
@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    uint index [[thread_position_in_grid]]);

template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]);
)";
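Note on the unary/ternary templates above: the positional fields {0}, {1}, {2} are filled by fmt::format when the JIT getters in jit_kernels.cpp build the Metal source. A minimal, self-contained sketch of that mechanism follows; the literal values ("abs_float32", "Abs") are made up for illustration, only the substitution pattern matches the real getters.

#include <fmt/format.h>
#include <string>
#include <string_view>

// Stand-in template with the same placeholder layout as unary_kernels above.
constexpr std::string_view demo_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    uint index [[thread_position_in_grid]]);
)";

int main() {
  // {0} -> library/kernel suffix, {1} -> element type, {2} -> op struct name.
  std::string src = fmt::format(demo_kernels, "abs_float32", "float", "Abs");
  // src now instantiates unary_v<float, Abs> with host_name "vabs_float32".
  return 0;
}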
486  mlx/backend/metal/jit_kernels.cpp  Normal file
@@ -0,0 +1,486 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

using namespace fmt::literals;

namespace mlx::core {

std::string op_name(const array& arr) {
  std::ostringstream op_t;
  arr.primitive().print(op_t);
  return op_t.str();
}

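op_name reuses Primitive::print to stringify the primitive that produced an array, and that string selects the op struct inside the JIT'd Metal source. A hedged fragment (the "Add" output and the in-scope variables lib_name/out are assumptions for illustration):

// Assuming `out` was produced by an addition primitive:
std::string op = op_name(out); // e.g. "Add"
std::string src = fmt::format(
    binary_kernels, lib_name, "float", "float", op);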
MTL::ComputePipelineState* get_arange_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source
        << metal::utils() << metal::arange()
        << fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  std::string lib_name = kernel_name.substr(1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
                  << fmt::format(
                         unary_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

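Note the substr(1): unary kernel names carry a one-character dispatch prefix ("v" for the contiguous variant, "g" for the general one), so both spellings resolve to the same library and are compiled and cached once. A hedged usage sketch; the kernel name "vabs_float32" and the Stream variable `s` are assumptions, not necessarily the exact strings the unary dispatch builds:

auto& d = metal::device(s.device);
// substr(1) maps both "vabs_float32" and "gabs_float32" to the "abs_float32" library.
auto* kernel = get_unary_kernel(d, "vabs_float32", out);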
MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(2);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
                  << fmt::format(
                         binary_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(2);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops()
                  << metal::binary_two()
                  << fmt::format(
                         binary_two_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
                  << fmt::format(
                         ternary_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::copy()
                  << fmt::format(
                         copy_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_softmax_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool precise,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::softmax()
                  << fmt::format(
                         softmax_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         get_type_string(precise ? float32 : out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_scan_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    bool reverse,
    bool inclusive,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::scan()
                  << fmt::format(
                         scan_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out),
                         inclusive,
                         reverse);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    int bn,
    int tn) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::sort()
                  << fmt::format(
                         block_sort_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         bn,
                         tn);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_mb_sort_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& idx,
    int bn,
    int tn) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::sort()
                  << fmt::format(
                         multiblock_sort_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(idx.dtype()),
                         bn,
                         tn);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

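For the sort getters, bn (the block size) and tn (elements handled per thread) are baked into the template instantiation, so each (dtype, bn, tn) combination becomes its own cached library. A hedged call sketch; `d`, `name`, `in`, and `out` are assumed in scope and 512/8 are placeholder tile values:

auto* block_sort = get_sort_kernel(d, name, in, out, /*bn=*/512, /*tn=*/8);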
MTL::ComputePipelineState* get_reduce_init_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  auto lib = d.get_library(kernel_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::reduce_utils()
                  << fmt::format(
                         reduce_init_kernels,
                         kernel_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(kernel_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_reduce_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
                  << fmt::format(
                         non_atomic ? reduce_non_atomic_kernels
                                    : reduce_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

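In get_reduce_kernel, 64-bit outputs cannot use the atomic accumulation path, so int64/uint64 outputs are routed to the non-atomic template set. A hedged fragment of that branch; the kernel name below is a placeholder, and `d`, `in`, `out` are assumed in scope:

// With an int64 `out`, the getter formats reduce_non_atomic_kernels instead of
// reduce_kernels; the call site does not change.
auto* kernel = get_reduce_kernel(d, "colSum_int64", in, out); // name illustrative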
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_fused()
                  << fmt::format(
                         steel_gemm_fused_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "trans_a"_a = transpose_a,
                         "trans_b"_a = transpose_b);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}

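Unlike the element-wise getters, the steel GEMM templates are filled with fmt named arguments: "bm"_a = bm (from fmt::literals) fills every {bm} field, and bool values such as trans_a render as the Metal literals true/false. A minimal sketch of the same call shape; `lib_name` and `out` are assumed in scope and the 64/64/16, 2x2 tile sizes are placeholders:

using namespace fmt::literals;
std::string src = fmt::format(
    steel_gemm_fused_kernels,
    "name"_a = lib_name,
    "itype"_a = get_type_string(out.dtype()),
    "bm"_a = 64, "bn"_a = 64, "bk"_a = 16,
    "wm"_a = 2, "wn"_a = 2,
    "trans_a"_a = false, "trans_b"_a = true);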
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_splitk()
                  << fmt::format(
                         steel_gemm_splitk_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(in.dtype()),
                         "otype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "trans_a"_a = transpose_a,
                         "trans_b"_a = transpose_b,
                         "mn_aligned"_a = mn_aligned,
                         "k_aligned"_a = k_aligned);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out,
    bool axbpy) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_splitk()
                  << fmt::format(
                         axbpy ? steel_gemm_splitk_accum_axbpy_kernels
                               : steel_gemm_splitk_accum_kernels,
                         "name"_a = lib_name,
                         "atype"_a = get_type_string(in.dtype()),
                         "otype"_a = get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

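The axbpy flag picks between the two accumulation templates from steel_gemm.h above: the plain variant only sums the K partitions into D, while the axpby variant additionally binds C, ldc, fdc, alpha, and beta so the partial sums fold into alpha * (A @ B) + beta * C. A hedged call sketch; `d`, `name`, `c_split`, and `out` are assumed in scope:

auto* accum = get_steel_gemm_splitk_accum_kernel(
    d, name, c_split, out, /*axbpy=*/true); // instantiates gemm_splitk_accum_axpby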
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    bool mn_aligned,
    bool k_aligned) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    auto out_mask_type = mask_out.has_value()
        ? get_type_string((*mask_out).dtype())
        : "nomask_t";
    auto op_mask_type =
        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
    kernel_source << metal::utils() << metal::gemm()
                  << metal::steel_gemm_masked()
                  << fmt::format(
                         steel_gemm_masked_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "outmasktype"_a = out_mask_type,
                         "opmasktype"_a = op_mask_type,
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "trans_a"_a = transpose_a,
                         "trans_b"_a = transpose_b,
                         "mn_aligned"_a = mn_aligned,
                         "k_aligned"_a = k_aligned);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

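Absent masks are substituted with the nomask_t tag type, so one masked-GEMM template covers every combination of present and absent output/operand masks. A hedged sketch with only an operand mask supplied; `d`, `name`, `out`, and `op_mask` are assumed in scope and the tile sizes are placeholders:

auto* kernel = get_steel_gemm_masked_kernel(
    d, name, out,
    /*mask_out=*/std::nullopt, /*mask_op=*/op_mask,
    /*transpose_a=*/false, /*transpose_b=*/false,
    /*bm=*/32, /*bn=*/32, /*bk=*/16, /*wm=*/2, /*wn=*/2,
    /*mn_aligned=*/true, /*k_aligned=*/true);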
MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn,
    int n_channel_specialization,
    bool small_filter) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
                  << fmt::format(
                         steel_conv_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn,
                         "n_channels"_a = n_channel_specialization,
                         "small_filter"_a = small_filter);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::conv()
                  << metal::steel_conv_general()
                  << fmt::format(
                         steel_conv_general_kernels,
                         "name"_a = lib_name,
                         "itype"_a = get_type_string(out.dtype()),
                         "bm"_a = bm,
                         "bn"_a = bn,
                         "bk"_a = bk,
                         "wm"_a = wm,
                         "wn"_a = wn);
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

} // namespace mlx::core
Some files were not shown because too many files have changed in this diff.