Compare commits

..

4 Commits

Author SHA1 Message Date
Awni Hannun
0dbe80a024 try again with checkpointed classes 2024-03-06 10:38:04 -08:00
Awni Hannun
a5827d0384 docs for checkpoint + a few more tests 2024-03-06 10:38:04 -08:00
Awni Hannun
1368bce280 fix tests and add setter attributes 2024-03-06 10:38:04 -08:00
Awni Hannun
8918a437bb checkpoint module's __call__ 2024-03-06 10:38:04 -08:00
454 changed files with 20379 additions and 61391 deletions

View File

@@ -13,62 +13,8 @@ parameters:
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
@@ -85,59 +31,53 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
python3 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
brew install openmpi
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install torch
pip install tensorflow
@@ -146,60 +86,35 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
DEVICE=cpu ./build/tests/tests
build_release:
parameters:
@@ -214,20 +129,20 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
@@ -236,20 +151,19 @@ jobs:
command: |
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w
- when:
condition: << parameters.build_env >>
@@ -262,7 +176,7 @@ jobs:
- store_artifacts:
path: dist/
build_linux_release:
build_linux_test_release:
parameters:
python_version:
type: string
@@ -291,28 +205,22 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
@@ -327,12 +235,8 @@ workflows:
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- mac_build_and_test
- linux_build_and_test
- build_documentation
build_pypi_release:
when:
@@ -349,17 +253,9 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
@@ -372,9 +268,6 @@ workflows:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -386,8 +279,8 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
weekly_build:
when:
and:
@@ -397,17 +290,17 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >>
- << pipeline.parameters.test_release >>
jobs:
- build_linux_release:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -17,4 +17,4 @@ jobs:
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
pre-commit run --all-files

View File

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
rev: v17.0.6
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
@@ -14,7 +14,3 @@ repos:
- id: isort
args:
- --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format

View File

@@ -7,19 +7,14 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@@ -1,24 +0,0 @@
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT

View File

@@ -15,43 +15,35 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.21.0)
set(MLX_VERSION 0.5.1)
endif()
# --------------------- Processor tests -------------------------
message(
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(
FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif()
@@ -63,199 +55,150 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
set(FOUNDATION_LIB "-framework Foundation")
set(QUARTZ_LIB "-framework QuartzCore")
if (MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found
execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
COMMAND
zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
mlx PUBLIC
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
endif()
if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if(NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if(NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
if(WIN32)
find_package(dlfcn-win32 REQUIRED)
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
endif()
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
if(MLX_BUILD_TESTS)
if (MLX_BUILD_TESTS)
include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()
if(MLX_BUILD_EXAMPLES)
if (MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if(MLX_BUILD_BENCHMARKS)
if (MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)
# Install library
install(
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
# Install headers
install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING
PATTERN "*.h"
PATTERN "backend/metal/kernels.h" EXCLUDE)
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING PATTERN "*.h"
)
# Install metal dependencies
if(MLX_BUILD_METAL)
if (MLX_BUILD_METAL)
# Install metal cpp
install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source)
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source
)
endif()
@@ -267,24 +210,31 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
EXPORT MLXTargets
FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
include(CMakePackageConfigHelpers)
write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION})
VERSION ${MLX_VERSION}
)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
MLX_CMAKE_INSTALL_MODULE_DIR)
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
install(
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
install(
DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

View File

@@ -6,7 +6,7 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon,
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:
@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

View File

@@ -17,13 +17,14 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
double time_fn(F fn, Args... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@@ -144,13 +144,6 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@@ -512,8 +505,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")

View File

@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
y = torch.nn.functional.mish(y)
return torch.nn.functional.mish(y)
sync_if_needed(x)
@@ -283,14 +283,6 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
@@ -454,11 +446,5 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
raise ValueError("Unknown benchmark")

View File

@@ -16,9 +16,7 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
def compare(args):

View File

@@ -9,6 +9,7 @@ from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -50,6 +51,7 @@ def bench_gelu():
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)

View File

@@ -1,123 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,127 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,143 +0,0 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal([10, 256, 256, 3])
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(10, 3, 256, 256, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 20
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@@ -1,129 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,110 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,143 +0,0 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal(shape)
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(*shape, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 10
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@@ -1,116 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose3d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -28,11 +28,11 @@ def bench(f, a, b):
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
y = mx.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
mx.eval(ys)
return ys
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
y = torch.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
torch.mps.synchronize()
return ys
@@ -53,12 +53,11 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
@@ -68,15 +67,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
@@ -85,7 +84,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
@@ -96,40 +95,35 @@ if __name__ == "__main__":
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
N, H, W, C, kH, kW, O, strides, padding, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,135 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,66 +0,0 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
world = mx.distributed.init()
if world.rank() == 0:
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
if world.rank() == 0:
print(f"{msec:.5f} msec")
def time_all_sum():
shape = (4096,)
x = mx.random.uniform(shape=shape)
mx.eval(x)
def sine(x):
for _ in range(20):
x = mx.sin(x)
return x
time_fn(sine, x)
def all_sum_plain(x):
for _ in range(20):
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_plain, x)
def all_sum_with_sine(x):
for _ in range(20):
x = mx.sin(x)
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_with_sine, x)
if __name__ == "__main__":
time_all_sum()

View File

@@ -1,84 +0,0 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@@ -1,118 +0,0 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
def fft_mlx(x):
if dim == 1:
out = mx.fft.fft(x)
elif dim == 2:
out = mx.fft.fft2(x)
mx.eval(out)
return out
def fft_mps(x):
if dim == 1:
out = torch.fft.fft(x)
elif dim == 2:
out = torch.fft.fft2(x)
torch.mps.synchronize()
return out
bandwidths = []
for n in fft_sizes:
batch_size = system_size // n**dim
shape = [batch_size] + [n for _ in range(dim)]
if backend == "mlx":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = mx.array(x_np)
mx.eval(x)
fft = fft_mlx
elif backend == "mps":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()
fft = fft_mps
else:
raise NotImplementedError()
runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
print(n, bandwidth)
bandwidths.append(bandwidth)
return np.array(bandwidths)
def time_fft():
x = np.array(range(2, 512))
system_size = int(2**26)
print("MLX GPU")
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
print("MPS GPU")
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
print("CPU")
system_size = int(2**20)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
x = np.array(x)
all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)
for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
print(name)
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()
av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)
if __name__ == "__main__":
time_fft()

View File

@@ -1,70 +0,0 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@@ -1,41 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -1,39 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -6,21 +6,21 @@ from time_utils import time_fn
def time_rope():
rope = nn.RoPE(64)
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x, offset=100)
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):

View File

@@ -9,7 +9,7 @@ from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[tuple(idx)] = x
dst[*idx] = x
mx.eval(dst)
idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
def gather(dst, x, idx, device):
dst[*idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
@@ -54,7 +54,7 @@ if __name__ == "__main__":
(100_000, 64),
(1_000_000, 64),
(100_000,),
(200_000,),
(2_000_00,),
(20_000_000,),
(10000, 64),
(100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -1,189 +0,0 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def bench(f, *args):
for i in range(N_warmup):
f(*args)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(*args)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float16", "float32")[:1]
transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -1,94 +0,0 @@
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn
L = 32768
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
bits = 8
loops = 20
def attention(q, k, v):
for _ in range(loops):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
ke = k[:, :, None, :, :]
ve = v[:, :, None, :, :]
s = q @ ke.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
q = p @ ve
q = q.reshape(B, Hq, L, D)
return q
def sdpa(q, k, v):
for _ in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
return q
def quant_sdpa(q, k, v, bits=4):
for _ in range(loops):
q = mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=1.0, mask=None, bits=bits
)
return q
def quant_attention(q, k, v, bits=4):
for _ in range(loops):
B, Hq, L, D = q.shape
Hk = k[0].shape[1]
q = q.reshape((B, Hk, Hq // Hk, L, D))
ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
scores = mx.softmax(scores, axis=-1)
q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
q = q.reshape((B, Hq, L, D))
return q
def time_self_attention_primitives(q, k, v):
time_fn(attention, q, k, v)
def time_self_attention_sdpa(q, k, v):
time_fn(sdpa, q, k, v)
def time_self_attention_quant_sdpa(q, k, v, bits=4):
time_fn(quant_sdpa, q, k, v, bits)
def time_self_attention_quant_primitives(q, k, v, bits=4):
time_fn(quant_attention, q, k, v, bits)
if __name__ == "__main__":
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
mx.eval(q, k, v)
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)
k = mx.dequantize(*k_quant, bits=bits)
v = mx.dequantize(*v_quant, bits=bits)
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)

View File

@@ -1,41 +1,56 @@
include(CMakeParseArguments)
# ##############################################################################
###############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers)
# Args:
# TARGET: Custom target to be added for the metal library
# TITLE: Name of the .metallib
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
#
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cmake_parse_arguments(
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS}
${MTLLIB_SOURCES}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM)
VERBATIM
)
# Add metallib custom target
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
add_custom_target(
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib)
endmacro(mlx_build_metallib)

View File

@@ -1,50 +0,0 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,16 +2,12 @@
### Setup (do once)
Install Doxygen:
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
```
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
conda install sphinx
pip install sphinx-book-theme
```
### Build
@@ -19,7 +15,7 @@ pip install -r requirements.txt
Build the docs from `mlx/docs/`
```
doxygen && make html
make html
```
View the docs by running a server in `mlx/docs/build/html/`:

View File

@@ -1,4 +0,0 @@
sphinx
breathe
sphinx-book-theme
mlx

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 746 KiB

View File

@@ -1,20 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@@ -22,7 +22,6 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
@@ -30,20 +29,16 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
main_doc = "index"
master_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output -------------------------------------------------
@@ -60,39 +55,7 @@ html_theme_options = {
},
}
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -3,5 +3,4 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -1,427 +0,0 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
.. note::
We are only required to pass the body of the Metal kernel in ``source``.
The full function signature will be generated using:
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
Using Shape/Strides
-------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``

View File

@@ -1,16 +1,24 @@
Custom Extensions in MLX
========================
Developer Documentation
=======================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
MLX provides a open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
Introducing the Example
-----------------------
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
.. code-block:: python
@@ -19,35 +27,44 @@ You can do that in MLX directly:
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementation and
function transformations to MLX.
This function performs that operation while leaving the implementations and
differentiation to MLX.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
Operations and Primitives
-------------------------
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations
^^^^^^^^^^^
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
.. code-block:: C++
@@ -66,7 +83,10 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to this operation is in terms of existing operations:
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be do so in terms
of existing operations.
.. code-block:: C++
@@ -80,23 +100,25 @@ The simplest way to this operation is in terms of existing operations:
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy to use interface that use :class:`Primitive` building blocks.
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy to use interface that build
on top of the building blocks we call :class:`Primitive`.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete:
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` . Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so lets take a step
back and go to our example to give ourselves a more concrete image.
.. code-block:: C++
@@ -112,15 +134,11 @@ more concrete:
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -129,8 +147,7 @@ more concrete:
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
@@ -138,7 +155,7 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -158,22 +175,22 @@ more concrete:
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
@@ -206,7 +223,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -221,26 +238,27 @@ This operation now handles the following:
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
of these functions to allocate memory as needed
Implementing the CPU Back-end
Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
@@ -278,19 +296,19 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
}
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -303,26 +321,28 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
"Axpby is only supported for floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
floats. We can only direct to it for ``float32`` types
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call the :meth:`catlas_saxpby` from accelerate.
.. code-block:: C++
@@ -336,7 +356,17 @@ It allocates data for the output, copies ``y`` into it, and then calls the
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -359,20 +389,18 @@ It allocates data for the output, copies ``y`` into it, and then calls the
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
Great! But what about the inputs that do not fit the criteria for accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -382,33 +410,35 @@ For inputs that do not fit the criteria for accelerate, we fall back to
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
Implementing the GPU Back-end
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.
.. note::
Here are some helpful resources if you are new to Metal:
Here are some helpful resources if you are new to metal!
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
.. code-block:: C++
@@ -427,14 +457,15 @@ element in the output.
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
each instantiation a unique host name so we can identify the right kernel for
each data type.
.. code-block:: C++
@@ -457,21 +488,29 @@ each instantiation a unique host name so we can identify it.
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/ shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -479,22 +518,23 @@ below.
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -502,21 +542,21 @@ below.
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
@@ -530,30 +570,33 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
.. code-block:: C++
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -568,12 +611,12 @@ one we just defined:
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
@@ -582,35 +625,34 @@ one we just defined:
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
throw std::runtime_error("Axpby has no vmap implementation.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
Let's look at the overall directory structure first.
| extensions
| ├── axpby
@@ -624,39 +666,40 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the Python package
the python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use nanobind_ to build a Python API for the C++ library. Since bindings for
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.
already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -668,17 +711,17 @@ already provided, adding our :meth:`axpby` is simple.
Returns:
array: ``alpha * x + beta * y``
)");
)pbdoc");
}
Most of the complexity in the above example comes from additional bells and
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
@@ -686,8 +729,8 @@ whistles such as the literal names and doc-strings.
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
Building the C++ extension library itself is simple, it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
.. code-block:: cmake
@@ -709,12 +752,12 @@ CONFIG)`` and then link it to your library.
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice:
Here is what that looks like in practice!
.. code-block:: cmake
@@ -736,29 +779,27 @@ Here is what that looks like in practice:
endif()
Finally, we build the nanobind_ bindings
Finally, we build the Pybind11_ bindings
.. code-block:: cmake
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:
build utilities defined in :mod:`mlx.extension` for a simple build process.
.. code-block:: python
.. code-block:: python
from mlx import extension
from setuptools import setup
@@ -768,50 +809,48 @@ build utilities defined in :mod:`mlx.extension`:
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.8",
python_requires=">=3.7",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This results in the directory structure:
This will result in a directory structure as follows:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
Let's look at a simple script and its results:
Let's looks at a simple script and it's results!
.. code-block:: python
@@ -824,7 +863,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
Output:
@@ -835,12 +874,12 @@ Output:
c correctness: True
Results
^^^^^^^
^^^^^^^^^^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.
.. code-block:: python
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
@@ -859,7 +898,7 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
alpha = 4.0
beta = 2.0
mx.eval(x, y)
mx.eval((x, y))
def bench(f):
# Warm up
@@ -880,23 +919,30 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!
Results:
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
:meth:`grad`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
The full example code is available in `mlx <code>`_.
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@@ -1,68 +0,0 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the Llama attention layer which notably uses the RoPE
We will start with the llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
we will import as `mnist`.
.. code-block:: python

View File

@@ -43,7 +43,6 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
.. toctree::
@@ -59,18 +58,15 @@ are the CPU and GPU.
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::
@@ -84,5 +80,3 @@ are the CPU and GPU.
:maxdepth: 1
dev/extensions
dev/metal_debugger
dev/custom_metal_kernels

View File

@@ -14,11 +14,11 @@ silicon computer is
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5
- Using a native Python >= 3.8
- macOS >= 13.3
.. note::
MLX is only available on devices running macOS >= 13.5
MLX is only available on devices running macOS >= 13.3
It is highly recommended to use macOS 14 (Sonoma)
@@ -54,7 +54,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@@ -70,36 +70,39 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Then simply build and install MLX using pip:
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
For developing, install the package with development dependencies, and use an
editable install:
Then simply build and install it using pip:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
Once the development dependencies are installed, you can build faster with:
For developing use an editable install:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
Run the tests with:
To make sure the install is working run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your
IDE:
Optional: Install stubs to enable auto completions and type checking from your IDE:
.. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs
C++ API
@@ -120,7 +123,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
cmake .. && make -j
Run tests with:
@@ -139,7 +142,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
@@ -153,67 +156,31 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
@@ -235,12 +202,12 @@ Then set the active developer directory:
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
x86 Shell
~~~~~~~~~
.. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +231,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.
wipe your build cahce with ``rm -rf build/`` and try again.

View File

@@ -10,39 +10,27 @@ Array
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.conj
array.cos
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.dtype
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
@@ -52,11 +40,7 @@ Array
array.split
array.sqrt
array.square
array.squeeze
array.std
array.sum
array.swapaxes
array.transpose
array.T
array.var
array.view

View File

@@ -1,5 +1,7 @@
.. _data_types:
:orphan:
Data Types
==========
@@ -42,27 +44,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``int64``
- 8
- 64-bit signed integer
* - ``bfloat16``
- 2
- 16-bit brain float (e8, m7)
* - ``float16``
- 2
- 16-bit IEEE float (e5, m10)
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
* - ``complex64``
- 8
- 64-bit complex float
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype

View File

@@ -16,4 +16,3 @@ Devices and Streams
new_stream
set_default_stream
stream
synchronize

View File

@@ -1,22 +0,0 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather
send
recv
recv_like

View File

@@ -1,15 +0,0 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention
metal_kernel

View File

@@ -8,13 +8,5 @@ Linear Algebra
.. autosummary::
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
cross
qr
svd
eigvalsh
eigh

View File

@@ -3,18 +3,12 @@ Metal
.. currentmodule:: mlx.core.metal
.. autosummary::
.. autosummary::
:toctree: _autosummary
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -173,7 +173,7 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
checkpoint
.. toctree::

View File

@@ -13,13 +13,10 @@ simple functions.
:template: nn-module-template.rst
elu
celu
gelu
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -32,7 +29,6 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@@ -12,58 +12,32 @@ Layers
ALiBi
AvgPool1d
AvgPool2d
AvgPool3d
BatchNorm
CELU
Conv1d
Conv2d
Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout
Dropout2d
Dropout3d
Embedding
ELU
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LogSigmoid
LogSoftmax
LSTM
MaxPool1d
MaxPool2d
MaxPool3d
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
Sigmoid
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample
Upsample

View File

@@ -30,7 +30,6 @@ Module
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze

View File

@@ -5,14 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
addmm
all
allclose
allclose
any
arange
arccos
@@ -20,76 +19,49 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
broadcast_to
ceil
clip
concatenate
conj
conjugate
convolve
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
exp
expm1
expand_dims
eye
flatten
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
isinf
isnan
isneginf
isposinf
issubdtype
left_shift
isneginf
isinf
less
less_equal
linspace
@@ -107,32 +79,22 @@ Operations
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones
ones_like
outer
partition
pad
power
prod
put_along_axis
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat
reshape
right_shift
roll
round
rsqrt
save
@@ -151,7 +113,6 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum
@@ -163,13 +124,11 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril
triu
var
view
where
zeros
zeros_like

View File

@@ -1,7 +1,5 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -31,48 +29,8 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
.. toctree::
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm

View File

@@ -38,11 +38,8 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
gumbel
key
normal
multivariate_normal
randint
seed
split
truncated_normal
uniform
laplace
permutation

View File

@@ -10,7 +10,6 @@ Transforms
eval
compile
custom_function
disable_compile
enable_compile
grad
@@ -18,3 +17,4 @@ Transforms
jvp
vjp
vmap
checkpoint

View File

@@ -19,5 +19,3 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

View File

@@ -33,12 +33,12 @@ Let's start with a simple example:
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
@@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python
def gelu(x):
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
@@ -136,6 +136,13 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
@@ -280,7 +287,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
print(fun(mx.array(1.0)))
Compiling Training Graphs
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
@@ -290,7 +297,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -323,7 +330,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -348,7 +355,7 @@ appropriate input and output captures. Here's the same example but compiled:
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -403,7 +410,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`

View File

@@ -1,166 +0,0 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. note::
A lot of operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
.. note::
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
.. code::
host1 slots=1
host2 slots=1
When using MLX, it is very likely that you want to use 1 slot per host, ie one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together our training loop step looks as follows with
everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Tuning All Reduce
-----------------
We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:
1. Perform a few large reductions instead of many small ones to improve
bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
connections between each host to improve bandwidth

View File

@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
function. To get the second derivative you can do:
.. code-block:: shell
@@ -40,7 +40,7 @@ getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
graphs.
.. note::
@@ -114,7 +114,7 @@ way to do that is the following:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
@@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
where the vectorized axes should be in the outputs.
Let's time these two different versions:
@@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output
general, MLX has limited support for operations for which outputs
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
In Place Updates
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:

View File

@@ -13,12 +13,12 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation lets us record a compute graph without actually doing any
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
@@ -109,14 +109,14 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.

View File

@@ -3,11 +3,7 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.
.. code-block:: python
@@ -66,7 +62,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.

View File

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
gradient with respect to the function's input.

View File

@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
- Multiple arrays
* - Safetensors
- ``.safetensors``
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
:func:`load` depends on the format.
Here's an example of saving a single array to a file:
@@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy")
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:

View File

@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:
without needing to move them from one memory location to another. For example:
.. code-block:: python

View File

@@ -8,5 +8,3 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)

View File

@@ -1,22 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
metal::stop_capture();
}

View File

@@ -89,8 +89,8 @@ void automatic_differentiation() {
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = grad(grad(fn))(x);
// d2fdx2 is 2
auto df2dx2 = grad(grad(fn))(x);
// df2dx2 is 2
}
int main() {

View File

@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.27)
project(_ext LANGUAGES CXX)
project(mlx_sample_extensions LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
@@ -11,16 +11,8 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
@@ -28,10 +20,16 @@ find_package(nanobind CONFIG REQUIRED)
add_library(mlx_ext)
# Add sources
target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
@@ -40,35 +38,29 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET
mlx_ext_metallib
TITLE
mlx_ext
SOURCES
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
add_dependencies(mlx_ext mlx_ext_metallib)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
_ext
NB_STATIC
STABLE_ABI
LTO
NOMINSIZE
NB_DOMAIN
mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext)
# ----------------------------- Pybind -----------------------------
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()

View File

@@ -1,24 +0,0 @@
## Build
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```
## Test
```
python test.py
```

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
@@ -43,7 +43,7 @@ array axpby(
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -61,7 +61,7 @@ array axpby(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -106,12 +106,12 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& out_arr) {
auto out = out_arr[0];
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -150,7 +150,11 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -176,11 +180,11 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& outarr) {
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -191,7 +195,7 @@ void Axpby::eval_cpu(
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
eval(inputs, outarr);
}
#else // Accelerate not available
@@ -199,8 +203,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
std::vector<array>& out) {
eval(inputs, out);
}
#endif
@@ -214,12 +218,12 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& outarr) {
// Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -249,15 +253,16 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
@@ -265,22 +270,22 @@ void Axpby::eval_gpu(
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3);
compute_encoder.set_bytes(beta_, 4);
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim if needed
if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8);
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
// We launch 1 thread for each input and make sure that the number of
@@ -295,7 +300,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available
@@ -367,4 +372,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
/** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(const std::vector<array>& inputs, std::vector<array>& out);
};
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -2,6 +2,7 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T>
@@ -18,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
@@ -30,33 +31,33 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
instantiate_axpby(complex64, complex64_t);

View File

@@ -1,31 +1,31 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "axpby/axpby.h"
namespace nb = nanobind;
using namespace nb::literals;
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ NB_MODULE(_ext, m) {
Returns:
array: ``alpha * x + beta * y``
)");
}
)pbdoc");
}

View File

@@ -2,4 +2,4 @@
import mlx.core as mx
from ._ext import axpby
from .mlx_sample_extensions import *

View File

@@ -1,8 +1,3 @@
[build-system]
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.18.0",
"nanobind==2.2.0",
]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +0,0 @@
setuptools>=42
cmake>=3.24
mlx>=0.21.0
nanobind==2.2.0

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
from setuptools import setup
@@ -9,9 +9,10 @@ if __name__ == "__main__":
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.8",

View File

@@ -1,10 +0,0 @@
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -28,19 +28,10 @@ endif()
if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
set_and_check(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS}
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
)
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
else()
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
endif()
endif()
set_target_properties(mlx PROPERTIES
@@ -49,4 +40,4 @@ set_target_properties(mlx PROPERTIES
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@@ -1,40 +1,37 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_ACCELERATE)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
else()
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
)
endif()
if(MLX_BUILD_METAL)
if (MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)

View File

@@ -19,26 +19,15 @@ Buffer malloc(size_t size) {
}
void free(Buffer buffer) {
allocator().free(buffer);
return allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
return Buffer{std::malloc(size)};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
std::free(buffer.raw_ptr());
}
Buffer malloc_or_wait(size_t size) {

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr) {};
Buffer(void* ptr) : ptr_(ptr){};
// Get the raw data pointer from the buffer
void* raw_ptr();
@@ -41,7 +41,6 @@ class Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -58,7 +57,6 @@ class CommonAllocator : public Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;

View File

@@ -1,6 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include <unordered_map>
#include "mlx/array.h"
#include "mlx/ops.h"
@@ -12,16 +12,22 @@ namespace mlx::core {
namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -30,11 +36,22 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
init(&cval);
}
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
dtype,
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
@@ -42,16 +59,15 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -76,10 +92,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
set_data(data, deleter);
}
@@ -88,42 +104,22 @@ void array::detach() {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
set_status(Status::available);
return true;
}
return false;
}
void array::wait() {
if (!is_available()) {
event().wait();
set_status(Status::available);
}
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::unscheduled) {
mlx::core::eval({*this});
} else {
wait();
}
mlx::core::eval({*this});
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
return array_desc_->is_tracer && in_tracing();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -168,142 +164,51 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::~array() {
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
if (!do_detach) {
break;
}
}
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
for (const auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
}
array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
: shape(std::move(shape)),
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destroy their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
if (inputs.empty()) {
return;
}
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
std::unordered_map<std::uintptr_t, array> input_map;
for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
}
}
ad.inputs.clear();
for (auto& [_, a] : input_map) {
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
};
append_deletable_inputs(*this);
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
@@ -9,7 +8,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
namespace mlx::core {
@@ -33,7 +31,7 @@ class array {
template <typename It>
array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -49,13 +47,13 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter = allocator::free);
@@ -73,32 +71,32 @@ class array {
this->array_desc_ = other.array_desc_;
}
return *this;
}
};
/** The size of the array's datatype in bytes. */
size_t itemsize() const {
return size_of(dtype());
}
};
/** The number of elements in the array. */
size_t size() const {
return array_desc_->size;
}
};
/** The number of bytes in the array. */
size_t nbytes() const {
return size() * itemsize();
}
};
/** The number of dimensions of the array. */
size_t ndim() const {
return array_desc_->shape.size();
}
};
/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
return array_desc_->shape;
}
};
/**
* Get the size of the corresponding dimension.
@@ -107,26 +105,17 @@ class array {
* bounds checking. */
int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}
};
/** The strides of the array. */
const std::vector<size_t>& strides() const {
return array_desc_->strides;
}
/**
* Get the stride of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}
};
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
}
};
/** Evaluate the array. */
void eval();
@@ -160,10 +149,10 @@ class array {
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
return a.arr.id() == b.arr.id() && a.idx == b.idx;
}
};
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
return !(a == b);
}
};
private:
const array& arr;
@@ -183,16 +172,22 @@ class array {
* API may change.
*/
array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
std::vector<array>&& inputs);
static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@@ -209,7 +204,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {}
: buffer(buffer), d(d){};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -219,45 +214,33 @@ class array {
};
struct Flags {
// True iff there are no gaps in the underlying data. Each item
// True if there are no gaps in the underlying data. Each item
// in the underlying data buffer belongs to at least one index.
//
// True iff:
// prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
bool contiguous : 1;
// True iff:
// strides[-1] == 1 and
// all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
// range(ndim - 1))
bool row_contiguous : 1;
// True iff:
// strides[0] == 1 and
// all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
// range(1, ndim))
bool col_contiguous : 1;
};
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
}
};
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
}
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
}
};
/** The array's inputs. */
const std::vector<array>& inputs() const {
return array_desc_->inputs;
}
};
std::vector<array>& inputs() {
return array_desc_->inputs;
@@ -271,12 +254,7 @@ class array {
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
}
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
}
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
@@ -293,6 +271,11 @@ class array {
outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
@@ -301,32 +284,19 @@ class array {
/** Get the Flags bit-field. */
const Flags& flags() const {
return array_desc_->flags;
}
};
/** The size (in elements) of the underlying buffer the array points to.
*
* This can be different than the actual size of the array if the array has
* been broadcast or irregularly strided. If ``first`` is the offset into
* the data buffer of the first element of the array (i.e. the offset
* corresponding to ``arr[0, 0, ...]``) and last is the offset into the
* data buffer of the last element of the array (i.e. the offset
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
/** The size (in elements) of the underlying buffer the array points to. */
size_t data_size() const {
return array_desc_->data_size;
}
};
allocator::Buffer& buffer() {
return array_desc_->data->buffer;
}
};
const allocator::Buffer& buffer() const {
return array_desc_->data->buffer;
}
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
};
// Return a copy of the shared pointer
// to the array::Data struct
@@ -337,57 +307,16 @@ class array {
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
}
};
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
}
enum Status {
// The ouptut of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
// The ouptut of a computation which has been scheduled but `eval_*` has
// not yet been called on the array's primitive. A possible
// status of `x` in `auto x = a + b; eval(x);`
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
evaluated,
// If the array is the output of a computation then the computation
// is complete. Constant arrays are always available (e.g. `array({1, 2,
// 3})`)
available
};
// Check if the array is safe to read.
bool is_available() const;
// Wait on the array to be available. After this `is_available` returns
// `true`.
void wait();
Status status() const {
return array_desc_->status;
}
void set_status(Status s) const {
array_desc_->status = s;
}
// Get the array's shared event
Event& event() const {
return array_desc_->event;
}
// Attach an event to a not yet evaluated array
void attach_event(Event e) const {
array_desc_->event = std::move(e);
// Check if the array has been evaluated
bool is_evaled() const {
return array_desc_->data != nullptr;
}
// Mark the array as a tracer array (true) or not.
@@ -415,21 +344,12 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
~array();
private:
// Initialize the arrays data
template <typename It>
@@ -440,12 +360,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
Status status;
// An event on the array used for synchronization
Event event;
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -453,12 +368,14 @@ class array {
// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data;
std::shared_ptr<Data> data{nullptr};
// Properly offset data pointer
void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses
// This can be different than the actual size of the array if it
// has been broadcast or irregularly strided.
size_t data_size;
// Contains useful meta data about the array
@@ -471,26 +388,29 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
const std::vector<array>& inputs);
~ArrayDesc();
private:
// Initialize size, strides, and other metadata
void init();
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_;
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
template <typename T>
@@ -502,9 +422,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
init(data);
}
@@ -521,9 +441,9 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
@@ -545,11 +465,10 @@ T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (status() == Status::unscheduled) {
if (!is_evaled()) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
const_cast<array*>(this)->eval();
return *data<T>();
}
@@ -599,15 +518,4 @@ void array::init(It src) {
}
}
/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
} // namespace mlx::core

View File

@@ -1,8 +1,10 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)

View File

@@ -1,9 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

View File

@@ -1,8 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"
@@ -195,40 +196,6 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
return matmul_bnns_general(a_pre, b_pre, out);
}
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {

View File

@@ -3,7 +3,8 @@
#include <cassert>
#include <cmath>
#include <Accelerate/Accelerate.h>
#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
@@ -30,27 +31,21 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
@@ -73,15 +68,10 @@ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -103,7 +93,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -118,7 +108,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
binary(
a,
b,
out,
@@ -133,7 +123,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
}
@@ -197,26 +187,6 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -288,7 +258,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == int32) {
binary_op<int>(
binary(
a,
b,
out,
@@ -301,7 +271,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -316,7 +286,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
}
@@ -327,21 +297,12 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
eval(inputs, out);
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
@@ -390,8 +351,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
eval(inputs, out);
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
@@ -401,7 +366,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -416,7 +381,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
}
@@ -427,7 +392,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
eval(inputs, out);
unary(in, out, [](auto x) { return -x; });
}
}
@@ -514,7 +479,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
eval(inputs, out);
unary(in, out, [](auto x) { return x * x; });
}
}
@@ -540,7 +505,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -558,7 +523,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
binary(
a,
b,
out,
@@ -570,7 +535,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
},
UseDefaultBinaryOp());
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
}

View File

@@ -18,61 +18,49 @@ void _qmm_t_4_64(
const float* biases,
int M,
int N,
int K,
int B,
bool batched_w) {
int K) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
*result = simd_reduce_add(sum);
result++;
sum += (*x_local) * wf;
x_local++;
}
}
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
*result = simd_reduce_add(sum);
result++;
}
x += K;
}
}
@@ -94,10 +82,8 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.shape(-2);
int M = x.size() / K;
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
@@ -106,9 +92,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
biases.data<float>(),
M,
N,
K,
B,
batched_w);
K);
} else {
eval(inputs, out);
}

View File

@@ -2,73 +2,86 @@
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
};
}
} // namespace
// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
VT _sum = {0};
while (size >= N) {
_sum += (*(VT*)x);
x += N;
size -= N;
}
T sum = _sum[0];
for (int i = 1; i < N; i++) {
sum += _sum[i];
}
*accum += sum;
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -81,11 +94,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_sum<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
@@ -99,11 +111,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_max<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
@@ -117,11 +128,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_min<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);

View File

@@ -1,12 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
@@ -33,8 +30,8 @@ namespace {
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
auto x = x_init * 1.442695; // multiply with log_2(e)
inline simd_float16 simd_fast_exp(simd_float16 x) {
x *= 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
@@ -53,30 +50,28 @@ inline simd_float16 simd_fast_exp(simd_float16 x_init) {
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
// Avoid supressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
return (*(simd_float16*)&epart) * x;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
@@ -112,55 +107,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
@@ -177,7 +123,7 @@ struct AccelerateSimdOps {
VT max(VT a, VT b) {
return simd_max(a, b);
}
};
VT exp(VT x) {
return simd_fast_exp(x);
@@ -208,7 +154,54 @@ struct AccelerateSimdOps {
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
};
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -225,21 +218,13 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
T maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
maximum = std::max(maximum, *current_in_ptr);
current_in_ptr++;
}
@@ -249,29 +234,18 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
T normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
@@ -280,33 +254,14 @@ void softmax(const array& in, array& out) {
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
*current_out_ptr *= normalizer;
current_out_ptr++;
}
}
@@ -353,33 +308,15 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
break;
case bfloat16:
eval(inputs, out);

Some files were not shown because too many files have changed in this diff Show More