Compare commits


1 Commit

Author: Alex Barron | SHA1: 152092957c | Message: Add NF4 quant | Date: 2024-06-27 13:16:31 -07:00
319 changed files with 10959 additions and 22369 deletions

View File

@@ -13,62 +13,8 @@ parameters:
test_release: test_release:
type: boolean type: boolean
default: false default: false
linux_release:
type: boolean
default: false
jobs: jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test: linux_build_and_test:
docker: docker:
- image: cimg/python:3.9 - image: cimg/python:3.9
@@ -85,24 +31,19 @@ jobs:
name: Install dependencies name: Install dependencies
command: | command: |
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.2.0 pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install numpy pip install numpy
sudo apt-get update sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
- run: - run:
name: Install Python package name: Install Python package
command: | command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \ CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
echo "stubs" echo "stubs"
pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
@@ -111,9 +52,7 @@ jobs:
- run: - run:
name: Build CPP only name: Build CPP only
command: | command: |
mkdir -p build && cd build mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run: - run:
name: Run CPP tests name: Run CPP tests
command: ./build/tests/tests command: ./build/tests/tests
@@ -131,13 +70,13 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
brew install python@3.9 brew install python@3.8
brew install openmpi brew install openmpi
python3.9 -m venv env python3.8 -m venv env
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.2.0 pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install numpy pip install numpy
pip install torch pip install torch
pip install tensorflow pip install tensorflow
@@ -146,12 +85,11 @@ jobs:
name: Install Python package name: Install Python package
command: | command: |
source env/bin/activate source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
source env/bin/activate source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Run Python tests name: Run Python tests
@@ -159,7 +97,7 @@ jobs:
source env/bin/activate source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run: - run:
name: Build example extension name: Build example extension
command: | command: |
@@ -173,7 +111,7 @@ jobs:
name: Build CPP only name: Build CPP only
command: | command: |
source env/bin/activate source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu` mkdir -p build && cd build && cmake .. && make -j
- run: - run:
name: Run CPP tests name: Run CPP tests
command: | command: |
@@ -183,23 +121,8 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
cd build/ cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \ cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
-DBUILD_SHARED_LIBS=ON \ make -j
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
build_release: build_release:
parameters: parameters:
@@ -226,7 +149,7 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.2.0 pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools pip install --upgrade setuptools
pip install numpy pip install numpy
pip install twine pip install twine
@@ -236,20 +159,19 @@ jobs:
command: | command: |
source env/bin/activate source env/bin/activate
DEV_RELEASE=1 \ DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v pip install . -v
- run: - run:
name: Generate package stubs name: Generate package stubs
command: | command: |
source env/bin/activate source env/bin/activate
pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- run: - run:
name: Build Python package name: Build Python package
command: | command: |
source env/bin/activate source env/bin/activate
<< parameters.build_env >> \ << parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \ CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w python -m build -w
- when: - when:
condition: << parameters.build_env >> condition: << parameters.build_env >>
@@ -262,7 +184,7 @@ jobs:
- store_artifacts: - store_artifacts:
path: dist/ path: dist/
build_linux_release: build_linux_test_release:
parameters: parameters:
python_version: python_version:
type: string type: string
@@ -291,28 +213,21 @@ jobs:
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install --upgrade cmake pip install --upgrade cmake
pip install nanobind==2.2.0 pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools pip install --upgrade setuptools
pip install numpy pip install numpy
pip install auditwheel pip install auditwheel
pip install patchelf pip install patchelf
pip install build pip install build
pip install twine
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v pip install . -v
pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
<< parameters.extra_env >> \ << parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \ CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel python -m build --wheel
auditwheel show dist/* auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64 auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts: - store_artifacts:
path: wheelhouse/ path: wheelhouse/
@@ -330,9 +245,8 @@ workflows:
- mac_build_and_test: - mac_build_and_test:
matrix: matrix:
parameters: parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test - linux_build_and_test
- build_documentation
build_pypi_release: build_pypi_release:
when: when:
@@ -349,17 +263,9 @@ workflows:
ignore: /.*/ ignore: /.*/
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"] xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"] build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb: prb:
when: when:
matches: matches:
@@ -374,7 +280,7 @@ workflows:
requires: [ hold ] requires: [ hold ]
matrix: matrix:
parameters: parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test: - linux_build_and_test:
requires: [ hold ] requires: [ hold ]
nightly_build: nightly_build:
@@ -386,7 +292,7 @@ workflows:
- build_release: - build_release:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"] xcode_version: ["15.0.0", "15.2.0"]
weekly_build: weekly_build:
when: when:
@@ -397,17 +303,17 @@ workflows:
- build_release: - build_release:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"] xcode_version: ["15.0.0", "15.2.0"]
build_env: ["DEV_RELEASE=1"] build_env: ["DEV_RELEASE=1"]
linux_test_release: linux_test_release:
when: when:
and: and:
- equal: [ main, << pipeline.git.branch >> ] - equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.linux_release >> - << pipeline.parameters.test_release >>
jobs: jobs:
- build_linux_release: - build_linux_test_release:
matrix: matrix:
parameters: parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"] extra_env: ["PYPI_RELEASE=1"]
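One practical detail in the config hunks above: both sides of this comparison drive the CMake build through pip, but one side sets CMAKE_BUILD_PARALLEL_LEVEL to the machine's core count (`nproc` on Linux, `sysctl -n hw.ncpu` on macOS) while the other leaves it empty, and one side pins nanobind to 2.2.0 while the other installs it from a git commit. A minimal, hedged sketch of the "Install Python package" step as a standalone script, assuming it is run from an MLX source checkout (the environment variables are the ones used in the config; everything else is illustrative):

    import multiprocessing
    import os
    import subprocess
    import sys

    env = os.environ.copy()
    # Parallel CMake jobs, equivalent to `nproc` / `sysctl -n hw.ncpu` in the CI steps.
    env["CMAKE_BUILD_PARALLEL_LEVEL"] = str(multiprocessing.cpu_count())
    # CPU-only build, as in the Linux jobs above; drop this on Apple silicon to build the Metal backend.
    env["CMAKE_ARGS"] = "-DMLX_BUILD_METAL=OFF"
    subprocess.run([sys.executable, "-m", "pip", "install", ".", "-v"], env=env, check=True)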

View File

@@ -1,11 +1,11 @@
repos: repos:
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8 rev: v18.1.4
hooks: hooks:
- id: clang-format - id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0 rev: 24.4.2
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
@@ -14,7 +14,3 @@ repos:
- id: isort - id: isort
args: args:
- --profile=black - --profile=black
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.13
hooks:
- id: cmake-format

View File

@@ -7,18 +7,16 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals: MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention. - Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors"> <a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -1,24 +0,0 @@
cff-version: 1.2.0
title: mlx
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Awni
family-names: Hannun
affiliation: Apple
- given-names: Jagrit
family-names: Digani
affiliation: Apple
- given-names: Angelos
family-names: Katharopoulos
affiliation: Apple
- given-names: Ronan
family-names: Collobert
affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
MLX: efficient and flexible machine learning on Apple
silicon
license: MIT

View File

@@ -24,43 +24,35 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION) if(NOT MLX_VERSION)
set(MLX_VERSION 0.20.0) set(MLX_VERSION 0.15.1)
endif() endif()
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
message( message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
STATUS
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)
set(MLX_BUILD_ARM OFF) set(MLX_BUILD_ARM OFF)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC) if(NOT MLX_ENABLE_X64_MAC)
message( message(FATAL_ERROR
FATAL_ERROR "Building for x86_64 on macOS is not supported."
"Building for x86_64 on macOS is not supported." " If you are on an Apple silicon system, check the build"
" If you are on an Apple silicon system, check the build" " documentation for possible fixes: "
" documentation for possible fixes: " "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
)
else() else()
set(MLX_BUILD_METAL OFF)
message(WARNING "Building for x86_64 arch is not officially supported.") message(WARNING "Building for x86_64 arch is not officially supported.")
endif() endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif() endif()
else() else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.") message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif() endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib ----------------------------- # ----------------------------- Lib -----------------------------
include(FetchContent) include(FetchContent)
@@ -69,59 +61,64 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx) add_library(mlx)
if(MLX_BUILD_METAL) if (MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal") find_library(METAL_LIB Metal)
set(FOUNDATION_LIB "-framework Foundation") find_library(FOUNDATION_LIB Foundation)
set(QUARTZ_LIB "-framework QuartzCore") find_library(QUARTZ_LIB QuartzCore)
endif() endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB) if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU") message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF) set(MLX_METAL_DEBUG OFF)
elseif(MLX_BUILD_METAL) elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources") message(STATUS "Building METAL sources")
if(MLX_METAL_DEBUG) if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG) add_compile_definitions(MLX_METAL_DEBUG)
endif() endif()
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process( execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" OUTPUT_VARIABLE MACOS_VERSION
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY) COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_VERSION} LESS 14.0)
message(
FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip if (${MACOS_VERSION} GREATER_EQUAL 15.0)
) set(MLX_METAL_VERSION METAL_3_2)
# Get the metal version elseif (${MACOS_VERSION} GREATER_EQUAL 14.2)
execute_process( set(MLX_METAL_VERSION METAL_3_1)
COMMAND elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
zsh "-c" set(MLX_METAL_VERSION METAL_3_0)
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'" else()
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL}) FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
)
FetchContent_MakeAvailable(metal_cpp) FetchContent_MakeAvailable(metal_cpp)
target_include_directories( target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}> mlx PUBLIC
$<INSTALL_INTERFACE:include/metal_cpp>) $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) $<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}") add_compile_definitions(${MLX_METAL_VERSION})
endif() endif()
if(MLX_BUILD_CPU) if (MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate) find_library(ACCELERATE_LIBRARY Accelerate)
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY) if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON) set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -133,29 +130,32 @@ if(MLX_BUILD_CPU)
# The blas shipped in macOS SDK is not supported, search homebrew for # The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead. # openblas instead.
set(BLA_VENDOR OpenBLAS) set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif() endif()
# Search and link with lapack. # Search and link with lapack.
find_package(LAPACK REQUIRED) find_package(LAPACK REQUIRED)
if(NOT LAPACK_FOUND) if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed") message(FATAL_ERROR "Must have LAPACK installed")
endif() endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/local/opt/openblas/include) /usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES}) target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old # List blas after lapack otherwise we may accidentally incldue an old version
# version of lapack.h from the include dirs of blas. # of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED) find_package(BLAS REQUIRED)
if(NOT BLAS_FOUND) if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed") message(FATAL_ERROR "Must have BLAS installed")
endif() endif()
# TODO find a cleaner way to do this # TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include find_path(BLAS_INCLUDE_DIRS cblas.h
$ENV{BLAS_HOME}/include) /usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES}) message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
@@ -166,95 +166,96 @@ else()
endif() endif()
find_package(MPI) find_package(MPI)
if(MPI_FOUND) if (MPI_FOUND)
execute_process( execute_process(
COMMAND zsh "-c" "mpirun --version" COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET) COMMAND_ERROR_IS_FATAL ANY
if(${MPI_VERSION} MATCHES ".*Open MPI.*") )
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH}) target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else() else()
set(MPI_FOUND FALSE) message(
message(WARNING "MPI which is not OpenMPI found. Building without MPI.") WARNING
"MPI which is not OpenMPI found. Building without MPI."
)
endif() endif()
endif() endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories( target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}> mlx
$<INSTALL_INTERFACE:include>) PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare( FetchContent_Declare(fmt
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1 GIT_TAG 10.2.1
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(fmt) FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>) target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
if(MLX_BUILD_PYTHON_BINDINGS) if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package( find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif() endif()
if(MLX_BUILD_TESTS) if (MLX_BUILD_TESTS)
include(CTest) include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif() endif()
if(MLX_BUILD_EXAMPLES) if (MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif() endif()
if(MLX_BUILD_BENCHMARKS) if (MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif() endif()
# ----------------------------- Installation ----------------------------- # ----------------------------- Installation -----------------------------
include(GNUInstallDirs) include(GNUInstallDirs)
# Install library # Install library
install( install(
TARGETS mlx TARGETS mlx
EXPORT MLXTargets EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) )
# Install headers # Install headers
install( install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers COMPONENT headers
FILES_MATCHING FILES_MATCHING PATTERN "*.h"
PATTERN "*.h" )
PATTERN "backend/metal/kernels.h" EXCLUDE)
# Install metal dependencies # Install metal dependencies
if(MLX_BUILD_METAL) if (MLX_BUILD_METAL)
# Install metal cpp # Install metal cpp
install( install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/ DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source) COMPONENT metal_cpp_source
)
endif() endif()
@@ -266,24 +267,31 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install( install(
EXPORT MLXTargets EXPORT MLXTargets
FILE MLXTargets.cmake FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
include(CMakePackageConfigHelpers) include(CMakePackageConfigHelpers)
write_basic_package_version_file( write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION}) VERSION ${MLX_VERSION}
)
configure_package_config_file( configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG} ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR} INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
MLX_CMAKE_INSTALL_MODULE_DIR) )
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG} install(
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(DIRECTORY ${CMAKE_MODULE_PATH}/ install(
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}) DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

View File

@@ -6,7 +6,7 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx) [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning on Apple silicon, MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research. brought to you by Apple machine learning research.
Some key features of MLX include: Some key features of MLX include:

View File

@@ -62,10 +62,17 @@ def matmul(x, y):
def _quant_matmul(x, w, s, b, transpose, group_size, bits): def _quant_matmul(x, w, s, b, transpose, group_size, bits):
ys = [] ys = []
for i in range(10): for i in range(100):
ys.append( ys.append(
mx.quantized_matmul( mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits x,
w,
s,
b,
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mx.QuantizationMode.DEFAULT,
) )
) )
mx.eval(ys) mx.eval(ys)
@@ -144,13 +151,6 @@ def reduction(op, axis, x):
mx.eval(ys) mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x): def softmax(axis, x):
ys = [] ys = []
for i in range(100): for i in range(100):
@@ -512,8 +512,5 @@ if __name__ == "__main__":
elif args.benchmark == "selu": elif args.benchmark == "selu":
print(bench(selu, x)) print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else: else:
raise ValueError("Unknown benchmark") raise ValueError("Unknown benchmark")
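The _quant_matmul hunk above is where this comparison touches the quantized matmul API: one side calls mx.quantized_matmul with only transpose, group_size, and bits, while the other also passes mode=mx.QuantizationMode.DEFAULT. A minimal sketch of the call using weights produced by mx.quantize, with the mode keyword left out since it only appears on one side of the diff (shapes are illustrative):

    import mlx.core as mx

    x = mx.random.normal((8, 512))
    w = mx.random.normal((512, 512))

    # mx.quantize packs the weights and returns per-group scales and biases.
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

    # Same argument names as the benchmark's _quant_matmul helper.
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, group_size=64, bits=4
    )
    mx.eval(y)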

View File

@@ -1,127 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,143 +0,0 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal([10, 256, 256, 3])
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=32):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(10, 3, 256, 256, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 20
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@@ -1,129 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,110 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,143 +0,0 @@
import time
import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
mx.set_default_device(mx.cpu)
class BenchNetMLX(mlx.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = mlx.nn.Sequential(
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
mlx.nn.ReLU(),
mlx.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
mlx.nn.ReLU(),
mlx.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def __call__(self, input):
return self.net(input)
benchNet = BenchNetMLX(3)
mx.eval(benchNet.parameters())
optim = opt.Adam(learning_rate=1e-3)
inputs = mx.random.normal(shape)
params = benchNet.parameters()
optim.init(params)
state = [benchNet.state, optim.state]
def loss_fn(params, image):
benchNet.update(params)
pred_image = benchNet(image)
return (pred_image - image).abs().mean()
def step(params, image):
loss, grads = mx.value_and_grad(loss_fn)(params, image)
optim.update(benchNet, grads)
return loss
total_time = 0.0
print("MLX:")
for i in range(steps):
start_time = time.perf_counter()
step(benchNet.parameters(), inputs)
mx.eval(state)
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
device = torch.device("cpu")
class BenchNetTorch(torch.nn.Module):
# simple encoder-decoder net
def __init__(self, in_channels, hidden_channels=16):
super().__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
torch.nn.ReLU(),
torch.nn.Conv3d(
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
),
torch.nn.ReLU(),
torch.nn.ConvTranspose3d(
hidden_channels, in_channels, kernel_size=3, padding=1
),
)
def forward(self, input):
return self.net(input)
benchNet = BenchNetTorch(3).to(device)
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
inputs = torch.randn(*shape, device=device)
def loss_fn(pred_image, image):
return (pred_image - image).abs().mean()
total_time = 0.0
print("PyTorch:")
for i in range(steps):
start_time = time.perf_counter()
optim.zero_grad()
pred_image = benchNet(inputs)
loss = loss_fn(pred_image, inputs)
loss.backward()
optim.step()
end_time = time.perf_counter()
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
total_time += (end_time - start_time) * 1000
return total_time
def main():
steps = 10
time_mlx = bench_mlx(steps)
time_torch = bench_torch(steps)
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
print(f"total time of MLX: {time_mlx:9.2f} ms")
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
print(f"total time of PyTorch: {time_torch:9.2f} ms")
diff = time_torch / time_mlx - 1.0
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
if __name__ == "__main__":
main()

View File

@@ -1,116 +0,0 @@
import argparse
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
def mx_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_3D
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
@torch.no_grad()
def pt_conv_3D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose3d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
return ys
return pt_conv_3D
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kD * kH * kW * C)
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
f_mx = make_mx_conv_3D(strides, padding, groups)
f_pt = make_pt_conv_3D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose3d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose3d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
)
for dtype in dtypes:
print(
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,135 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_transpose_2D
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_transpose_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv_transpose2d(
a, b, stride=strides, padding=padding, groups=groups
)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_transpose_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv_transpose2d(
a_mx, b_mx, stride=strides, padding=padding, groups=groups
)
out_pt = torch.conv_transpose2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,66 +0,0 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
world = mx.distributed.init()
if world.rank() == 0:
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
if world.rank() == 0:
print(f"{msec:.5f} msec")
def time_all_sum():
shape = (4096,)
x = mx.random.uniform(shape=shape)
mx.eval(x)
def sine(x):
for _ in range(20):
x = mx.sin(x)
return x
time_fn(sine, x)
def all_sum_plain(x):
for _ in range(20):
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_plain, x)
def all_sum_with_sine(x):
for _ in range(20):
x = mx.sin(x)
x = mx.distributed.all_sum(x)
return x
time_fn(all_sum_with_sine, x)
if __name__ == "__main__":
time_all_sum()

View File

@@ -1,84 +0,0 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@@ -1,70 +0,0 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@@ -9,7 +9,7 @@ from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes): def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx): def scatter(dst, x, idx):
dst[tuple(idx)] = x dst[*idx] = x
mx.eval(dst) mx.eval(dst)
idx = [] idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device): def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def scatter(dst, x, idx, device): def gather(dst, x, idx, device):
dst[tuple(idx)] = x dst[*idx] = x
if device == torch.device("mps"): if device == torch.device("mps"):
torch.mps.synchronize() torch.mps.synchronize()
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device) x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device) dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device) runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms") print(f"PyTorch: {runtime:.3f}ms")
@@ -54,7 +54,7 @@ if __name__ == "__main__":
(100_000, 64), (100_000, 64),
(1_000_000, 64), (1_000_000, 64),
(100_000,), (100_000,),
(200_000,), (2_000_00,),
(20_000_000,), (20_000_000,),
(10000, 64), (10000, 64),
(100, 64), (100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes): for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20) print("=" * 20)
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}") print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape) benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device) benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -1,49 +0,0 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
L = 1024
H = 32
H_k = 32 // 4
D = 128
def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -1,21 +1,30 @@
include(CMakeParseArguments) include(CMakeParseArguments)
# ############################################################################## ###############################################################################
# Build metal library # Build metal library
# #
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
# #
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of # Args:
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # TARGET: Custom target to be added for the metal library
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency # TITLE: Name of the .metallib
# files (like headers) # OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
# #
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output # Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
@@ -26,16 +35,22 @@ macro(mlx_build_metallib)
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET} OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND COMMAND xcrun -sdk macosx metal
xcrun -sdk macosx metal "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>" ${MTLLIB_COMPILE_OPTIONS}
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET} ${MTLLIB_SOURCES}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib" COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM) VERBATIM
)
# Add metallib custom target # Add metallib custom target
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET}) add_custom_target(
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib) endmacro(mlx_build_metallib)

View File

@@ -1,4 +1,3 @@
sphinx sphinx
breathe breathe
sphinx-book-theme sphinx-book-theme
mlx

View File

@@ -60,7 +60,6 @@ html_theme_options = {
}, },
} }
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output --------------------------------------------- # -- Options for HTMLHelp output ---------------------------------------------
@@ -84,15 +83,3 @@ def setup(app):
# -- Options for LaTeX output ------------------------------------------------ # -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")] latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -1,427 +0,0 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------
Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp",
input_names=["inp"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
.. note::
We are only required to pass the body of the Metal kernel in ``source``.
The full function signature will be generated using:
* The shapes/dtypes of ``inputs``
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
so we will add ``const device float16_t* inp`` to the signature.
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
in ``source``.
* The list of ``output_dtypes``
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
so we add ``device float16_t* out``.
* Template parameters passed using ``template``
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
and instantiates the template with ``custom_kernel_myexp_float<float>``.
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
These will be added as function arguments.
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:
.. code-block:: cpp
template <typename T>
[[kernel]] void custom_kernel_myexp_float(
const device float16_t* inp [[buffer(0)]],
device float16_t* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
}
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads in total, subdivided into threadgroups of size ``threadgroup``.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
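For example, a minimal sketch of the same call from above with ``verbose`` enabled (all other arguments unchanged):
.. code-block:: python
outputs = kernel(
    inputs=[a],
    template=[("T", mx.float32)],
    grid=(a.size, 1, 1),
    threadgroup=(256, 1, 1),
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
    verbose=True,  # prints the generated Metal source for inspection
)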
Using Shape/Strides
-------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python
def exp_elementwise(a: mx.array):
source = """
uint elem = thread_position_in_grid.x;
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
T tmp = inp[loc];
// Output arrays are always row contiguous
out[elem] = metal::exp(tmp);
"""
kernel = mx.fast.metal_kernel(
name="myexp_strided",
input_names=["inp"],
output_names=["out"],
source=source
)
outputs = kernel(
inputs=[a],
template=[("T", mx.float32)],
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[a.shape],
output_dtypes=[a.dtype],
ensure_row_contiguous=False,
)
return outputs[0]
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
# make non-contiguous
a = a[::2]
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
Complex Example
-----------------------------
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
We'll start with the following MLX implementation using standard ops:
.. code-block:: python
def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1
iy_ne = iy_nw
ix_sw = ix_nw
iy_sw = iy_nw + 1
ix_se = ix_nw + 1
iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel:
.. code-block:: python
@mx.custom_function
def grid_sample(x, grid):
assert x.ndim == 4, "`x` must be 4D."
assert grid.ndim == 4, "`grid` must be 4D."
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
out_shape = (B, gN, gM, C)
assert D == 2, "Last dim of `grid` must be size 2."
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride;
int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""
kernel = mx.fast.metal_kernel(
name="grid_sample",
input_names=["x", "grid"],
output_names=["out"],
source=source,
)
outputs = kernel(
inputs=[x, grid],
template=[("T", x.dtype)],
output_shapes=[out_shape],
output_dtypes=[x.dtype],
grid=(np.prod(out_shape), 1, 1),
threadgroup=(256, 1, 1),
)
return outputs[0]
For a reasonably sized input such as:
.. code-block:: python
x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement:
``55.7ms -> 6.7ms => 8x speed up``
Grid Sample VJP
---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:
* ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
* ``atomic_outputs=True``
Designate all of the kernel outputs as ``atomic`` in the function signature.
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
We can then implement the backwards pass as follows:
.. code-block:: python
@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
x, grid = primals
B, _, _, C = x.shape
_, gN, gM, D = grid.shape
assert D == 2, "Last dim of `grid` must be size 2."
source = """
uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C;
int h_stride = W * w_stride;
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix);
int iy_nw = floor(iy);
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C_padded / gH / gW * b_stride;
int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T gix = T(0);
T giy = T(0);
if (channel_idx < C) {
int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset];
gix -= I_nw * (iy_se - iy) * cot;
giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T I_ne = x[offset];
gix += I_ne * (iy_sw - iy) * cot;
giy -= I_ne * (ix - ix_sw) * cot;
}
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset];
gix -= I_sw * (iy - iy_ne) * cot;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset];
gix += I_se * (iy - iy_nw) * cot;
giy += I_se * (ix - ix_nw) * cot;
}
}
T gix_mult = W / 2;
T giy_mult = H / 2;
// Reduce across each simdgroup first.
// This is much faster than relying purely on atomics.
gix = simd_sum(gix);
giy = simd_sum(giy);
if (thread_index_in_simdgroup == 0) {
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
# pad the output channels to simd group size
# so that our `simd_sum`s don't overlap.
simdgroup_size = 32
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
grid_size = B * gN * gM * C_padded
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp:
``676.4ms -> 16.7ms => 40x speed up``
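As a quick sanity check (a hedged sketch with arbitrary shapes and a sum-based scalar loss, not part of the original example), the custom vjp can be exercised through ``mx.grad``:
.. code-block:: python
x = mx.random.uniform(shape=(2, 32, 32, 8))
grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))

def loss(x, grid):
    return grid_sample(x, grid).sum()

# Differentiate with respect to both inputs; this dispatches to grid_sample_vjp.
dx, dgrid = mx.grad(loss, argnums=(0, 1))(x, grid)
mx.eval(dx, dgrid)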

View File

@@ -486,8 +486,9 @@ below.
std::ostringstream kname; std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out); kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available // Make sure the metal library is available and look for it
d.register_library("mlx_ext"); // in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext"); auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer Attention layer
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^
We will start with the Llama attention layer which notably uses the RoPE We will start with the llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to key/value cache that will be concatenated with the provided keys and values to
support efficient inference. support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader `mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which <https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``. we will import as `mnist`.
.. code-block:: python .. code-block:: python

View File

@@ -85,4 +85,3 @@ are the CPU and GPU.
dev/extensions dev/extensions
dev/metal_debugger dev/metal_debugger
dev/custom_metal_kernels

View File

@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements: To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.9 - Using a native Python >= 3.8
- macOS >= 13.5 - macOS >= 13.5
.. note:: .. note::
@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
Then simply build and install MLX using pip: Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install . env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
For developing, install the package with development dependencies, and use an For developing use an editable install:
editable install:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]" env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
Once the development dependencies are installed, you can build faster with: To make sure the install is working run the tests with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
Run the tests with:
.. code-block:: shell .. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your Optional: Install stubs to enable auto completions and type checking from your IDE:
IDE:
.. code-block:: shell .. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs python setup.py generate_stubs
C++ API C++ API
@@ -195,7 +195,7 @@ GGUF, you can do:
.. code-block:: shell .. code-block:: shell
cmake .. \ cmake ..
-DCMAKE_BUILD_TYPE=MinSizeRel \ -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \ -DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \ -DMLX_BUILD_CPU=OFF \
@@ -240,7 +240,7 @@ x86 Shell
.. _build shell: .. _build shell:
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively. Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm, To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"`` If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported." but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again. wipe your build cahce with ``rm -rf build/`` and try again.

View File

@@ -24,7 +24,6 @@ Array
array.any array.any
array.argmax array.argmax
array.argmin array.argmin
array.conj
array.cos array.cos
array.cummax array.cummax
array.cummin array.cummin
@@ -53,10 +52,8 @@ Array
array.sqrt array.sqrt
array.square array.square
array.squeeze array.squeeze
array.std
array.sum
array.swapaxes array.swapaxes
array.sum
array.transpose array.transpose
array.T array.T
array.var array.var
array.view

View File

@@ -17,6 +17,3 @@ made available.
init init
all_sum all_sum
all_gather all_gather
send
recv
recv_like

View File

@@ -12,5 +12,3 @@ Fast
layer_norm layer_norm
rope rope
scaled_dot_product_attention scaled_dot_product_attention
affine_quantize
metal_kernel

View File

@@ -9,12 +9,7 @@ Linear Algebra
:toctree: _autosummary :toctree: _autosummary
inv inv
tri_inv
norm norm
cholesky cholesky
cholesky_inv
cross
qr qr
svd svd
eigvalsh
eigh

View File

@@ -14,7 +14,6 @@ Metal
get_cache_memory get_cache_memory
set_memory_limit set_memory_limit
set_cache_limit set_cache_limit
set_wired_limit
clear_cache clear_cache
start_capture start_capture
stop_capture stop_capture

View File

@@ -13,7 +13,6 @@ simple functions.
:template: nn-module-template.rst :template: nn-module-template.rst
elu elu
celu
gelu gelu
gelu_approx gelu_approx
gelu_fast_approx gelu_fast_approx

View File

@@ -13,18 +13,13 @@ Layers
AvgPool1d AvgPool1d
AvgPool2d AvgPool2d
BatchNorm BatchNorm
CELU
Conv1d Conv1d
Conv2d Conv2d
Conv3d Conv3d
ConvTranspose1d
ConvTranspose2d
ConvTranspose3d
Dropout Dropout
Dropout2d Dropout2d
Dropout3d Dropout3d
Embedding Embedding
ELU
GELU GELU
GLU GLU
GroupNorm GroupNorm
@@ -36,8 +31,6 @@ Layers
LayerNorm LayerNorm
LeakyReLU LeakyReLU
Linear Linear
LogSigmoid
LogSoftmax
LSTM LSTM
MaxPool1d MaxPool1d
MaxPool2d MaxPool2d
@@ -53,7 +46,6 @@ Layers
RoPE RoPE
SELU SELU
Sequential Sequential
Sigmoid
SiLU SiLU
SinusoidalPositionalEncoding SinusoidalPositionalEncoding
Softmin Softmin

View File

@@ -44,10 +44,6 @@ Operations
convolve convolve
conv1d conv1d
conv2d conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
conv_general conv_general
cos cos
cosh cosh
@@ -61,8 +57,6 @@ Operations
diagonal diagonal
divide divide
divmod divmod
einsum
einsum_path
equal equal
erf erf
erfinv erfinv
@@ -78,11 +72,8 @@ Operations
gather_qmm gather_qmm
greater greater
greater_equal greater_equal
hadamard_transform
identity identity
imag
inner inner
isfinite
isclose isclose
isinf isinf
isnan isnan
@@ -112,7 +103,6 @@ Operations
minimum minimum
moveaxis moveaxis
multiply multiply
nan_to_num
negative negative
not_equal not_equal
ones ones
@@ -122,17 +112,14 @@ Operations
pad pad
power power
prod prod
put_along_axis
quantize quantize
quantized_matmul quantized_matmul
radians radians
real
reciprocal reciprocal
remainder remainder
repeat repeat
reshape reshape
right_shift right_shift
roll
round round
rsqrt rsqrt
save save

View File

@@ -31,41 +31,6 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state. # Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state) mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For Note, not every optimizer configuation parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is that if a parameter can be scheduled,
then it will be included in the optimizer state.
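For example (a rough sketch; the exact keys depend on the optimizer and MLX version), you can inspect what ends up in the flattened state before saving it:
.. code-block:: python
for key, value in tree_flatten(optimizer.state):
    # For Adam this typically includes 'step', 'learning_rate' and the
    # per-parameter moment estimates, but not 'betas' or 'eps'.
    print(key, value.shape if hasattr(value, "shape") else value)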
.. toctree:: .. toctree::
optimizers/optimizer optimizers/optimizer

View File

@@ -44,5 +44,3 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
split split
truncated_normal truncated_normal
uniform uniform
laplace
permutation

View File

@@ -10,7 +10,6 @@ Transforms
eval eval
compile compile
custom_function
disable_compile disable_compile
enable_compile enable_compile
grad grad

View File

@@ -136,6 +136,13 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
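For reference, a minimal sketch of the pattern this section benchmarks (the exact ``gelu`` definition here is only illustrative):
.. code-block:: python
import math
import mlx.core as mx

def gelu(x):
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

compiled_gelu = mx.compile(gelu)
x = mx.random.uniform(shape=(32, 1024))
mx.eval(compiled_gelu(x))  # the first call compiles; later calls reuse the compiled graph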
Debugging Debugging
--------- ---------

View File

@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096)) ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys): def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[0])] return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition: Instead you can use :func:`vmap` to automatically vectorize the addition:
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
# Vectorize over the second dimension of x and the # Vectorize over the second dimension of x and the
# first dimension of y # first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1)) vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify corresponding input to vectorize over. Similarly, use ``out_axes`` to specify

View File

@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient. kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output general, MLX has limited support for operations for which outputs
*shapes* are dependent on input *data*. Other examples of these types of *shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`. single input version of :func:`numpy.where`.
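Where NumPy code would use a boolean mask, a shape-stable workaround in MLX (a sketch, not an official recipe) is to keep the full shape and mask arithmetically, for example with :func:`where`:
.. code-block:: python
x = mx.array([-2.0, 1.0, 3.0])
# Instead of x[x > 0], which would have a data-dependent shape,
# keep the shape fixed and replace the masked-out entries:
y = mx.where(x > 0, x, 0.0)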

View File

@@ -109,7 +109,7 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`, :obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array. saving functions) will also evaluate the array.
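A minimal illustration of the implicit evaluation described above (array values are arbitrary):
.. code-block:: python
import numpy as np
import mlx.core as mx

a = mx.array([1.0, 2.0])
b = a + 1        # lazily builds the graph; nothing is computed yet
print(b)         # printing forces an implicit evaluation
np.array(b)      # converting to a numpy.ndarray does as well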

View File

@@ -11,14 +11,10 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies ----------------------------- # ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED) find_package(MLX CONFIG REQUIRED)
find_package( find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
Python 3.8
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process( execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED) find_package(nanobind CONFIG REQUIRED)
@@ -28,10 +24,16 @@ find_package(nanobind CONFIG REQUIRED)
add_library(mlx_ext) add_library(mlx_ext)
# Add sources # Add sources
target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp) target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers # Add include headers
target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx # Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx) target_link_libraries(mlx_ext PUBLIC mlx)
@@ -41,32 +43,27 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib # Build metallib
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
mlx_build_metallib( mlx_build_metallib(
TARGET TARGET mlx_ext_metallib
mlx_ext_metallib TITLE mlx_ext
TITLE SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
mlx_ext INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
SOURCES OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal )
INCLUDE_DIRS
${PROJECT_SOURCE_DIR}
${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
add_dependencies(mlx_ext mlx_ext_metallib) add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif() endif()
# ----------------------------- Python Bindings ----------------------------- # ----------------------------- Python Bindings -----------------------------
nanobind_add_module( nanobind_add_module(
_ext _ext
NB_STATIC NB_STATIC STABLE_ABI LTO NOMINSIZE
STABLE_ABI NB_DOMAIN mlx
LTO ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
NOMINSIZE )
NB_DOMAIN
mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
target_link_libraries(_ext PRIVATE mlx_ext) target_link_libraries(_ext PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS) if(BUILD_SHARED_LIBS)

View File

@@ -249,8 +249,9 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_"); kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out); kname << type_to_name(out);
// Make sure the metal library is available // Make sure the metal library is available and look for it
d.register_library("mlx_ext"); // in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext"); auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -2,7 +2,7 @@
requires = [ requires = [
"setuptools>=42", "setuptools>=42",
"cmake>=3.24", "cmake>=3.24",
"mlx>=0.18.0", "mlx>=0.9.0",
"nanobind==2.2.0", "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.24 cmake>=3.24
mlx>=0.18.1 mlx>=0.9.0
nanobind==2.2.0 nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

View File

@@ -13,6 +13,7 @@ if __name__ == "__main__":
cmdclass={"build_ext": extension.CMakeBuild}, cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"], packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False, zip_safe=False,
python_requires=">=3.8", python_requires=">=3.8",
) )

View File

@@ -1,24 +1,25 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
if(MLX_BUILD_CPU) if (MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
@@ -26,15 +27,17 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_ACCELERATE) if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU) elseif(MLX_BUILD_CPU)
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp) PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
)
endif() endif()
if(MLX_BUILD_METAL) if (MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)

View File

@@ -23,22 +23,11 @@ void free(Buffer buffer) {
} }
Buffer CommonAllocator::malloc(size_t size, bool) { Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t)); return Buffer{std::malloc(size)};
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
} }
void CommonAllocator::free(Buffer buffer) { void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr()); std::free(buffer.raw_ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
} }
Buffer malloc_or_wait(size_t size) { Buffer malloc_or_wait(size_t size) {

View File

@@ -41,7 +41,6 @@ class Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
Allocator() = default; Allocator() = default;
Allocator(const Allocator& other) = delete; Allocator(const Allocator& other) = delete;
@@ -58,7 +57,6 @@ class CommonAllocator : public Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private: private:
CommonAllocator() = default; CommonAllocator() = default;

View File

@@ -17,10 +17,6 @@ bool in_tracing() {
return detail::InTracing::in_tracing(); return detail::InTracing::in_tracing();
} }
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace } // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */) array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -95,34 +91,18 @@ void array::detach() {
array_desc_->primitive = nullptr; array_desc_->primitive = nullptr;
} }
bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (status() == Status::evaluated && event().is_signaled()) {
set_status(Status::available);
return true;
}
return false;
}
void array::wait() {
if (!is_available()) {
event().wait();
set_status(Status::available);
}
}
void array::eval() { void array::eval() {
// Ensure the array is ready to be read // Ensure the array is ready to be read
if (status() == Status::unscheduled) { if (status() == Status::scheduled) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
mlx::core::eval({*this}); mlx::core::eval({*this});
} else {
wait();
} }
} }
bool array::is_tracer() const { bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph(); return array_desc_->is_tracer && in_tracing();
} }
void array::set_data(allocator::Buffer buffer, deleter_t d) { void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -178,10 +158,8 @@ void array::move_shared_buffer(
array_desc_->flags = flags; array_desc_->flags = flags;
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset; auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr; array_desc_->data_ptr = static_cast<void*>(
other.array_desc_->data_ptr = nullptr; static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
} }
void array::move_shared_buffer(array other) { void array::move_shared_buffer(array other) {
@@ -193,11 +171,10 @@ array::~array() {
return; return;
} }
// Ignore arrays that might be detached during eval // Ignore arrays that will be detached
if (status() == array::Status::scheduled) { if (status() != array::Status::unscheduled) {
return; return;
} }
// Break circular reference for non-detached arrays with siblings // Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) { if (auto n = siblings().size(); n > 0) {
bool do_detach = true; bool do_detach = true;
@@ -260,38 +237,25 @@ array::ArrayDesc::~ArrayDesc() {
// This calls recursively the destructor and can result in stack overflow, we // This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a // instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2. // max stack depth of 2.
if (inputs.empty()) {
return;
}
std::vector<std::shared_ptr<ArrayDesc>> for_deletion; std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) { for (array& a : inputs) {
std::unordered_map<std::uintptr_t, array> input_map; if (a.array_desc_.use_count() == 1) {
for (array& a : ad.inputs) { for_deletion.push_back(std::move(a.array_desc_));
if (a.array_desc_) {
input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
}
} }
ad.inputs.clear(); }
for (auto& [_, a] : input_map) {
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
};
append_deletable_inputs(*this);
while (!for_deletion.empty()) { while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays // top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector // with inputs have been moved into the vector
auto top = std::move(for_deletion.back()); auto top = std::move(for_deletion.back());
for_deletion.pop_back(); for_deletion.pop_back();
append_deletable_inputs(*top);
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
} }
} }

View File

@@ -5,6 +5,7 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <optional>
#include <vector> #include <vector>
#include "mlx/allocator.h" #include "mlx/allocator.h"
@@ -219,23 +220,11 @@ class array {
}; };
struct Flags { struct Flags {
// True iff there are no gaps in the underlying data. Each item // True if there are no gaps in the underlying data. Each item
// in the underlying data buffer belongs to at least one index. // in the underlying data buffer belongs to at least one index.
//
// True iff:
// prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
bool contiguous : 1; bool contiguous : 1;
// True iff:
// strides[-1] == 1 and
// all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
// range(ndim - 1))
bool row_contiguous : 1; bool row_contiguous : 1;
// True iff:
// strides[0] == 1 and
// all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
// range(1, ndim))
bool col_contiguous : 1; bool col_contiguous : 1;
}; };
@@ -303,16 +292,7 @@ class array {
return array_desc_->flags; return array_desc_->flags;
} }
/** The size (in elements) of the underlying buffer the array points to. /** The size (in elements) of the underlying buffer the array points to. */
*
* This can be different than the actual size of the array if the array has
* been broadcast or irregularly strided. If ``first`` is the offset into
* the data buffer of the first element of the array (i.e. the offset
* corresponding to ``arr[0, 0, ...]``) and last is the offset into the
* data buffer of the last element of the array (i.e. the offset
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
size_t data_size() const { size_t data_size() const {
return array_desc_->data_size; return array_desc_->data_size;
} }
@@ -324,10 +304,6 @@ class array {
return array_desc_->data->buffer; return array_desc_->data->buffer;
} }
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer // Return a copy of the shared pointer
// to the array::Data struct // to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const { std::shared_ptr<Data> data_shared_ptr() const {
@@ -344,33 +320,11 @@ class array {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
} }
enum Status { enum Status { unscheduled, scheduled, available };
// The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,
// The output of a computation which has been scheduled but `eval_*` has bool is_available() const {
// not yet been called on the array's primitive. A possible return status() == Status::available;
// status of `x` in `auto x = a + b; eval(x);` }
scheduled,
// The array's `eval_*` function has been run, but the computation is not
// necessarily complete. The array will have memory allocated and if it is
// not a tracer then it will be detached from the graph.
evaluated,
// If the array is the output of a computation then the computation
// is complete. Constant arrays are always available (e.g. `array({1, 2,
// 3})`)
available
};
// Check if the array is safe to read.
bool is_available() const;
// Wait on the array to be available. After this `is_available` returns
// `true`.
void wait();
Status status() const { Status status() const {
return array_desc_->status; return array_desc_->status;
@@ -459,6 +413,8 @@ class array {
void* data_ptr{nullptr}; void* data_ptr{nullptr};
// The size in elements of the data buffer the array accesses // The size in elements of the data buffer the array accesses
// This can be different than the actual size of the array if it
// has been broadcast or irregularly strided.
size_t data_size; size_t data_size;
// Contains useful meta data about the array // Contains useful meta data about the array
@@ -610,4 +566,6 @@ inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T> template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>; using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
enum QuantizationMode { DEFAULT, NF4 };
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,8 +1,10 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)

View File

@@ -1,9 +1,9 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023 Apple Inc.
#include <cassert> #include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h> #include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"

View File

@@ -2,7 +2,8 @@
#include <cassert> #include <cassert>
#include <Accelerate/Accelerate.h> #include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>
#include "mlx/backend/accelerate/utils.h" #include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"

View File

@@ -3,7 +3,8 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <Accelerate/Accelerate.h> #include <vecLib/vDSP.h>
#include <vecLib/vForce.h>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
@@ -36,7 +37,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate) DEFAULT(Concatenate)
DEFAULT(Conjugate) DEFAULT(Conjugate)
DEFAULT(Copy) DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms) DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends) DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod) DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements) DEFAULT(NumberOfElements)
@@ -50,7 +51,6 @@ DEFAULT(GatherMM)
DEFAULT(GatherQMM) DEFAULT(GatherQMM)
DEFAULT(Greater) DEFAULT(Greater)
DEFAULT(GreaterEqual) DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less) DEFAULT(Less)
DEFAULT(LessEqual) DEFAULT(LessEqual)
DEFAULT(Load) DEFAULT(Load)
@@ -81,7 +81,6 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT(Inverse) DEFAULT(Inverse)
DEFAULT(Cholesky) DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) { void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
@@ -103,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == float32) { if (a.dtype() == float32) {
binary_op<float>( binary(
a, a,
b, b,
out, out,
@@ -118,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
}); });
} else if (a.dtype() == int32) { } else if (a.dtype() == int32) {
binary_op<int>( binary(
a, a,
b, b,
out, out,
@@ -133,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n); vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
}); });
} else { } else {
eval(inputs, out); binary(a, b, out, [](auto x, auto y) { return x + y; });
} }
} }
@@ -288,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == int32) { if (a.dtype() == int32) {
binary_op<int>( binary(
a, a,
b, b,
out, out,
@@ -301,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n); vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
}); });
} else if (a.dtype() == float32) { } else if (a.dtype() == float32) {
binary_op<float>( binary(
a, a,
b, b,
out, out,
@@ -316,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
}); });
} else { } else {
eval(inputs, out); binary(a, b, out, [](auto x, auto y) { return x / y; });
} }
} }
@@ -327,8 +326,12 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out); set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else { } else {
eval(inputs, out); throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
} }
} }
@@ -390,8 +393,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size(); auto size = in.data_size();
vvlog1pf( vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else { } else {
eval(inputs, out); throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
} }
} }
@@ -401,7 +408,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == float32) { if (a.dtype() == float32) {
binary_op<float>( binary(
a, a,
b, b,
out, out,
@@ -416,7 +423,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
}); });
} else { } else {
eval(inputs, out); binary(a, b, out, [](auto x, auto y) { return x * y; });
} }
} }
@@ -427,7 +434,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out); set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size()); vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else { } else {
eval(inputs, out); unary(in, out, [](auto x) { return -x; });
} }
} }
@@ -514,7 +521,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size(); auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size); vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else { } else {
eval(inputs, out); unary(in, out, [](auto x) { return x * x; });
} }
} }
@@ -540,7 +547,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == float32) { if (a.dtype() == float32) {
binary_op<float>( binary(
a, a,
b, b,
out, out,
@@ -558,7 +565,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
}); });
} else if (a.dtype() == int32) { } else if (a.dtype() == int32) {
binary_op<int>( binary(
a, a,
b, b,
out, out,
@@ -570,7 +577,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
}, },
UseDefaultBinaryOp()); UseDefaultBinaryOp());
} else { } else {
eval(inputs, out); binary(a, b, out, [](auto x, auto y) { return x - y; });
} }
} }
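
Editor's note: the eval_cpu bodies in this file all follow the same shape — take the vDSP fast path when the dtype has one (float32, plus int32 for add/divide/subtract), otherwise fall through to the generic element-wise path. Below is a minimal, self-contained sketch of that dispatch; only vDSP_vadd is the real Accelerate call, the surrounding names are illustrative and not MLX's API.

// Minimal sketch of the Accelerate dispatch pattern (illustrative names):
// fused vDSP kernel for contiguous float32, plain loop as the fallback.
#include <cstddef>
#include <vector>
#ifdef __APPLE__
#include <Accelerate/Accelerate.h>  // provides vDSP_vadd
#endif

template <typename T>
void add_fallback(const T* a, const T* b, T* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + b[i];
  }
}

void add_dispatch(const float* a, const float* b, float* out, std::size_t n) {
#ifdef __APPLE__
  // vDSP_vadd(A, strideA, B, strideB, C, strideC, count)
  vDSP_vadd(a, 1, b, 1, out, 1, n);
#else
  add_fallback(a, b, out, n);
#endif
}

int main() {
  std::vector<float> a{1.f, 2.f, 3.f}, b{4.f, 5.f, 6.f}, out(3);
  add_dispatch(a.data(), b.data(), out.data(), a.size());
}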

View File

@@ -18,61 +18,49 @@ void _qmm_t_4_64(
const float* biases, const float* biases,
int M, int M,
int N, int N,
int K, int K) {
int B,
bool batched_w) {
constexpr int bits = 4; constexpr int bits = 4;
constexpr int group_size = 64; constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
int w_els = N * K / pack_factor; for (int m = 0; m < M; m++) {
int g_els = w_els * pack_factor / group_size; const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int i = 0; i < B; i++) { for (int n = 0; n < N; n++) {
for (int m = 0; m < M; m++) { const simd_float16* x_local = (simd_float16*)x;
const uint32_t* w_local = w; simd_float16 sum = 0;
const float* scales_local = scales; for (int k = 0; k < K; k += group_size) {
const float* biases_local = biases; float scale = *scales_local++;
float bias = *biases_local++;
for (int n = 0; n < N; n++) { for (int kw = 0; kw < packs_in_group; kw += 2) {
const simd_float16* x_local = (simd_float16*)x; // TODO: vectorize this properly
simd_float16 sum = 0; simd_uint16 wi;
for (int k = 0; k < K; k += group_size) { for (int e = 0; e < 2; e++) {
float scale = *scales_local++; uint32_t wii = *w_local++;
float bias = *biases_local++; for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
for (int kw = 0; kw < packs_in_group; kw += 2) { wii >>= bits;
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
} }
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
} }
} simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
*result = simd_reduce_add(sum); sum += (*x_local) * wf;
result++; x_local++;
}
} }
x += K; *result = simd_reduce_add(sum);
} result++;
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
} }
x += K;
} }
} }
@@ -94,10 +82,8 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (condition) { if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1); int K = x.shape(-1);
int M = x.shape(-2); int M = x.size() / K;
int N = out.shape(-1); int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64( _qmm_t_4_64(
out.data<float>(), out.data<float>(),
x.data<float>(), x.data<float>(),
@@ -106,9 +92,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
biases.data<float>(), biases.data<float>(),
M, M,
N, N,
K, K);
B,
batched_w);
} else { } else {
eval(inputs, out); eval(inputs, out);
} }
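
Editor's note: _qmm_t_4_64 operates on weights quantized to 4 bits with a group size of 64 — every uint32_t packs eight weights, and each group of 64 shares one float scale and one bias. A scalar sketch of just the dequantization step, under the same packing assumptions; the function and buffer names are made up for illustration.

// Sketch (assumed layout): unpack eight 4-bit weights per uint32_t and
// dequantize each group of 64 with its shared scale and bias.
#include <cstdint>
#include <vector>

void dequantize_4bit(const uint32_t* w, const float* scales,
                     const float* biases, float* out, int n_groups) {
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int pack_factor = 32 / bits;          // 8 weights per uint32
  constexpr uint32_t bitmask = (1u << bits) - 1;  // 0xF
  for (int g = 0; g < n_groups; ++g) {
    float scale = scales[g];
    float bias = biases[g];
    for (int p = 0; p < group_size / pack_factor; ++p) {
      uint32_t packed = *w++;
      for (int e = 0; e < pack_factor; ++e) {
        *out++ = static_cast<float>(packed & bitmask) * scale + bias;
        packed >>= bits;
      }
    }
  }
}

int main() {
  std::vector<uint32_t> w(8, 0x76543210u);  // one group of 64 packed weights
  std::vector<float> scales{0.1f}, biases{-0.5f}, out(64);
  dequantize_4bit(w.data(), scales.data(), biases.data(), out.data(), 1);
}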

View File

@@ -2,8 +2,8 @@
#include <cassert> #include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h> #include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/reduce.h" #include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"

View File

@@ -3,10 +3,7 @@
#include <cassert> #include <cassert>
#include <limits> #include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h> #include <arm_neon.h>
#endif
#include <simd/math.h> #include <simd/math.h>
#include <simd/vector.h> #include <simd/vector.h>
@@ -33,8 +30,8 @@ namespace {
* Note: The implementation below is a general fast exp. There could be faster * Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0. * implementations for numbers strictly < 0.
*/ */
inline simd_float16 simd_fast_exp(simd_float16 x_init) { inline simd_float16 simd_fast_exp(simd_float16 x) {
auto x = x_init * 1.442695; // multiply by log_2(e) x *= 1.442695; // multiply by log_2(e)
simd_float16 ipart, fpart; simd_float16 ipart, fpart;
simd_int16 epart; simd_int16 epart;
x = simd_clamp(x, -80, 80); x = simd_clamp(x, -80, 80);
@@ -53,30 +50,28 @@ inline simd_float16 simd_fast_exp(simd_float16 x_init) {
// bitshifting // bitshifting
epart = (simd_int(ipart) + 127) << 23; epart = (simd_int(ipart) + 127) << 23;
// Avoid suppressing NaNs return (*(simd_float16*)&epart) * x;
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
} }
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** /**
* The ARM neon equivalent of the fast exp above. * The ARM neon equivalent of the fast exp above.
*/ */
inline float16x8_t neon_fast_exp(float16x8_t x) { inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply by log_2(e) x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply by log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14 x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14 x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f)))); float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t fpart = vsubq_f16(x, ipart); float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f)); x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart); x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
// generate 2**ipart in the floating point representation using integer // generate 2**ipart in the floating point representation using integer
// bitshifting // bitshifting
@@ -112,55 +107,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
return vget_lane_f16(y, 0); return vget_lane_f16(y, 0);
} }
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT> template <typename T, typename VT>
struct AccelerateSimdOps { struct AccelerateSimdOps {
VT init(T a) { VT init(T a) {
@@ -208,6 +154,53 @@ struct AccelerateSimdOps {
} }
}; };
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N> template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) { void softmax(const array& in, array& out) {
Ops ops; Ops ops;
@@ -369,16 +362,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
AccelerateSimdOps<float, simd_float16>, AccelerateSimdOps<float, simd_float16>,
16>(in, out); 16>(in, out);
} else { } else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax< softmax<
float16_t, float16_t,
float16_t, float16_t,
float16x8_t, float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>, NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out); 8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} }
break; break;
case bfloat16: case bfloat16:
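
Editor's note: both fast-exp variants in this file use the same trick — rewrite exp(x) as 2^(x·log2 e), turn the integer part of the exponent into float exponent bits by shifting, and approximate 2^f for the fractional part with the degree-6 polynomial whose coefficients appear above. A scalar float32 sketch of the idea, including the NaN passthrough that the new simd_fast_exp adds; this is illustrative code, not the SIMD implementation.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

float fast_exp(float x) {
  if (std::isnan(x)) {
    return x;  // mirror the NaN passthrough added to simd_fast_exp
  }
  x *= 1.442695f;  // log2(e), so exp(x) becomes 2^x
  x = std::fmin(std::fmax(x, -80.f), 80.f);
  float ipart = std::floor(x + 0.5f);
  float f = x - ipart;
  // Horner evaluation of 2^f for f in [-0.5, 0.5] (same coefficients as above)
  float p = 1.535336188319500e-4f;
  p = p * f + 1.339887440266574e-3f;
  p = p * f + 9.618437357674640e-3f;
  p = p * f + 5.550332471162809e-2f;
  p = p * f + 2.402264791363012e-1f;
  p = p * f + 6.931472028550421e-1f;
  p = p * f + 1.0f;
  // Build 2^ipart by writing (ipart + 127) into the float exponent bits
  uint32_t bits = static_cast<uint32_t>(static_cast<int32_t>(ipart) + 127) << 23;
  float two_ip;
  std::memcpy(&two_ip, &bits, sizeof(two_ip));
  return two_ip * p;
}

int main() {
  std::printf("fast %f vs std %f\n", fast_exp(1.5f), std::exp(1.5f));
}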

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include <Accelerate/Accelerate.h> #include <vecLib/BNNS/bnns.h>
#include "mlx/dtype.h" #include "mlx/dtype.h"
namespace mlx::core { namespace mlx::core {

View File

@@ -1,4 +1,5 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER}) set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE) set(CLANG TRUE)
else() else()
@@ -6,57 +7,71 @@ else()
endif() endif()
add_custom_command( add_custom_command(
OUTPUT compiled_preamble.cpp OUTPUT compiled_preamble.cpp
COMMAND COMMAND /bin/bash
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER} ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${PROJECT_SOURCE_DIR} ${CLANG} ${COMPILER}
DEPENDS make_compiled_preamble.sh ${PROJECT_SOURCE_DIR}
compiled_preamble.h ${CLANG}
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h)
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp) DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx cpu_compiled_preamble) add_dependencies(mlx cpu_compiled_preamble)
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp )
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(IOS) if (IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp) target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp) target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
endif() endif()

View File

@@ -43,15 +43,13 @@ void set_binary_op_output_data(
array& out, array& out,
BinaryOpType bopt, BinaryOpType bopt,
bool donate_with_move = false) { bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b_donatable) { if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@@ -66,7 +64,7 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
if (a_donatable) { if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
@@ -81,13 +79,13 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
if (a_donatable) { if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} }
} else if (b_donatable) { } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@@ -102,14 +100,16 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::General: case BinaryOpType::General:
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) { if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} }
} else if ( } else if (
b_donatable && b.flags().row_contiguous && b.size() == out.size()) { b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@@ -122,7 +122,19 @@ void set_binary_op_output_data(
} }
} }
struct UseDefaultBinaryOp {}; struct UseDefaultBinaryOp {
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
struct DefaultVectorScalar { struct DefaultVectorScalar {
@@ -138,6 +150,18 @@ struct DefaultVectorScalar {
a++; a++;
} }
} }
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *b;
while (size-- > 0) {
auto dst = op(*a, scalar);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
}
}
}; };
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
@@ -154,6 +178,18 @@ struct DefaultScalarVector {
b++; b++;
} }
} }
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
T scalar = *a;
while (size-- > 0) {
auto dst = op(scalar, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
b++;
}
}
}; };
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op>
@@ -170,110 +206,204 @@ struct DefaultVectorVector {
b++; b++;
} }
} }
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
while (size-- > 0) {
auto dst = op(*a, *b);
*dst_a = dst.first;
*dst_b = dst.second;
dst_a++;
dst_b++;
a++;
b++;
}
}
}; };
template <typename T, typename U, typename Op, int D, bool Strided> template <typename T, typename U, typename Op>
void binary_op_dims( void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
const T* a, const T* a_ptr = a.data<T>();
const T* b, const T* b_ptr = b.data<T>();
U* out, U* dst = out.data<U>();
Op op, size_t a_idx = 0;
const std::vector<int>& shape, size_t b_idx = 0;
const std::vector<size_t>& a_strides, for (size_t i = 0; i < out.size(); ++i) {
const std::vector<size_t>& b_strides, dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
const std::vector<size_t>& out_strides, a_idx += a.strides()[0];
int axis) { b_idx += b.strides()[0];
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
} else {
*out = op(*a, *b);
}
}
out += stride_out;
a += stride_a;
b += stride_b;
} }
} }
template <typename T, typename U, bool Strided, typename Op> template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst += stride;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims( void binary_op_dispatch_dims(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
Op op, Op op,
int dim, int dim,
const std::vector<int>& shape, int stride) {
const std::vector<size_t>& a_strides, // Number of dimensions to loop over for vectorized ops
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) { switch (dim) {
case 1: case 1:
binary_op_dims<T, U, Op, 1, Strided>( binary_op_dims1<T, U, Op>(a, b, out, op, stride);
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
case 2: case 2:
binary_op_dims<T, U, Op, 2, Strided>( binary_op_dims2<T, U, Op>(a, b, out, op, stride);
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
} }
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3); const T* a_ptr = a.data<T>();
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3); const T* b_ptr = b.data<T>();
size_t stride = out_strides[dim - 4]; U* dst = out.data<U>();
for (size_t elem = 0; elem < a.size(); elem += stride) { for (size_t i = 0; i < out.size(); i += stride) {
binary_op_dims<T, U, Op, 3, Strided>( int a_idx = elem_to_loc(i, a.shape(), a.strides());
a_ptr + a_it.loc, int b_idx = elem_to_loc(i, b.shape(), b.strides());
b_ptr + b_it.loc, op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
out_ptr + elem, dst += stride;
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
} }
} }
@@ -320,33 +450,29 @@ void binary_op(
} }
// General computation so let's try to optimize // General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after // Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) { auto& strides = out.strides();
int d = arr_strides.size() - 1; auto leftmost_rc_dim = [&strides](const array& arr) {
for (; d >= 0 && arr_strides[d] == strides[d]; d--) { int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
} }
return d + 1; return d + 1;
}; };
auto a_rc_dim = leftmost_rc_dim(a_strides); auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b_strides); auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after // Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) { auto leftmost_s_dim = [](const array& arr) {
int d = arr_strides.size() - 1; int d = arr.ndim() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) { for (; d >= 0 && arr.strides()[d] == 0; d--) {
} }
return d + 1; return d + 1;
}; };
auto a_s_dim = leftmost_s_dim(a_strides); auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b_strides); auto b_s_dim = leftmost_s_dim(b);
auto ndim = new_shape.size(); auto ndim = out.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim; int dim = ndim;
@@ -368,27 +494,27 @@ void binary_op(
// Can be sure dim > 0 since otherwise we would have used one of the fully // Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not // contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity. // correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) { if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General; bopt = BinaryOpType::General;
dim = ndim; dim = ndim;
} else {
stride = strides[dim - 1];
} }
switch (bopt) { switch (bopt) {
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>( binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>( binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>( binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
break; break;
default: default:
binary_op_dispatch_dims<T, U, false>( binary_op_dispatch_dims<T, U>(a, b, out, op);
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break; break;
} }
} }
@@ -405,9 +531,9 @@ void binary_op(
// TODO: The following mess of constexpr evaluations can probably be achieved // TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler? // with template specializations and overloading. Would it be simpler?
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) { if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) { if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) { if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?) // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>( binary_op<T, T>(
a, a,
@@ -428,8 +554,7 @@ void binary_op(
DefaultVectorScalar<T, T, Op>(op), DefaultVectorScalar<T, T, Op>(op),
opvv); opvv);
} }
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>:: } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
value) {
// opsv and opvv were UseDefaultBinaryOp // opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>( binary_op<T, T>(
a, a,
@@ -444,8 +569,7 @@ void binary_op(
binary_op<T, T>( binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv); a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
} }
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>:: } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) { if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp // opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>( binary_op<T, T>(
@@ -461,8 +585,7 @@ void binary_op(
binary_op<T, T>( binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv); a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
} }
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>:: } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
value) {
// opvv was UseDefaultBinaryOp // opvv was UseDefaultBinaryOp
binary_op<T, T>( binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op)); a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
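
Editor's note: the rewritten binary_op first collapses adjacent contiguous dimensions across a, b, and out, so the general N-dimensional case reduces to as few nested loops as possible before the templated binary_op_dims recursion runs. A simplified single-array sketch of that collapsing step; the real collapse_contiguous_dims merges the same axes across several stride vectors at once, and this helper name and signature are assumptions for illustration.

// Merge adjacent axes whenever stride[i] == stride[i+1] * shape[i+1], so a
// row-contiguous (2, 3, 4) array collapses to a single axis of 24.
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<std::size_t>> collapse_dims(
    std::vector<int> shape, std::vector<std::size_t> strides) {
  std::vector<int> out_shape;
  std::vector<std::size_t> out_strides;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() &&
        out_strides.back() == strides[i] * static_cast<std::size_t>(shape[i])) {
      // Mergeable: the previous axis steps exactly over this one
      out_shape.back() *= shape[i];
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

int main() {
  auto [shape, strides] = collapse_dims({2, 3, 4}, {12, 4, 1});
  std::printf("%d axis, size %d, stride %zu\n",
              static_cast<int>(shape.size()), shape[0], strides[0]);
}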

View File

@@ -9,43 +9,168 @@ namespace mlx::core {
namespace { namespace {
template <typename T, typename U, typename Op, int D> template <typename T, typename U, typename Op>
void binary_op_dims( void binary_op_dims1(
const T* a, const array& a,
const T* b, const array& b,
U* out_a, array& out_a,
U* out_b, array& out_b,
Op op, Op op) {
const std::vector<int>& shape, const T* a_ptr = a.data<T>();
const std::vector<size_t>& a_strides, const T* b_ptr = b.data<T>();
const std::vector<size_t>& b_strides, U* dst_a = out_a.data<U>();
const std::vector<size_t>& out_strides, U* dst_b = out_b.data<U>();
int axis) { size_t a_idx = 0;
auto stride_a = a_strides[axis]; size_t b_idx = 0;
auto stride_b = b_strides[axis]; for (size_t i = 0; i < out_a.size(); ++i) {
auto stride_out = out_strides[axis]; auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
auto N = shape[axis]; dst_a[i] = dst.first;
dst_b[i] = dst.second;
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
for (int i = 0; i < N; i++) { template <typename T, typename U, typename Op>
if constexpr (D > 1) { void binary_op_dims1(
binary_op_dims<T, U, Op, D - 1>( const array& a,
a, const array& b,
b, array& out_a,
out_a, array& out_b,
out_b, Op op,
op, int stride) {
shape, const T* a_ptr = a.data<T>();
a_strides, const T* b_ptr = b.data<T>();
b_strides, U* dst_a = out_a.data<U>();
out_strides, U* dst_b = out_b.data<U>();
axis + 1); size_t a_idx = 0;
} else { size_t b_idx = 0;
std::tie(*out_a, *out_b) = op(*a, *b); for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst_a += stride;
dst_b += stride;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[1];
b_idx += b.strides()[1];
} }
a += stride_a; a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b += stride_b; b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
out_a += stride_out; }
out_b += stride_out; }
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst_a += stride;
dst_b += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
dst_a[out_idx] = dst.first;
dst_b[out_idx++] = dst.second;
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
} }
} }
@@ -56,160 +181,352 @@ void binary_op_dispatch_dims(
array& out_a, array& out_a,
array& out_b, array& out_b,
Op op) { Op op) {
auto [shape, strides] = collapse_contiguous_dims( switch (out_a.ndim()) {
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
int ndim = shape.size();
switch (ndim) {
case 1: case 1:
binary_op_dims<T, U, Op, 1>( binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return; return;
case 2: case 2:
binary_op_dims<T, U, Op, 2>( binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
a_ptr, return;
b_ptr, case 3:
out_a_ptr, binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
out_b_ptr, return;
op, case 4:
shape, binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
a_strides,
b_strides,
out_strides,
0);
return; return;
} }
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2); const T* a_ptr = a.data<T>();
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2); const T* b_ptr = b.data<T>();
size_t stride = out_strides[ndim - 3]; U* dst_a = out_a.data<U>();
for (size_t elem = 0; elem < a.size(); elem += stride) { U* dst_b = out_b.data<U>();
binary_op_dims<T, U, Op, 2>( for (size_t i = 0; i < out_a.size(); i++) {
a_ptr + a_it.loc, int a_idx = elem_to_loc(i, a.shape(), a.strides());
b_ptr + b_it.loc, int b_idx = elem_to_loc(i, b.shape(), b.strides());
out_a_ptr + elem, std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
} }
} }
template <typename T, typename U = T, typename Op> template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst_a = out_a.data<U>();
U* dst_b = out_b.data<U>();
for (size_t i = 0; i < out_a.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
dst_a += stride;
dst_b += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
out_a.data<U>(),
out_b.data<U>(),
out_a.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out_a.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out_a.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
outputs[0],
outputs[1],
op,
opsv,
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op( void binary_op(
const array& a, const array& a,
const array& b, const array& b,
std::vector<array>& outputs, std::vector<array>& outputs,
Op op) { Op op) {
auto bopt = get_binary_op_type(a, b); DefaultScalarVector<T, T, Op> opsv(op);
auto& out_a = outputs[0]; DefaultVectorScalar<T, T, Op> opvs(op);
auto& out_b = outputs[1]; DefaultVectorVector<T, T, Op> opvv(op);
set_binary_op_output_data(a, b, out_a, bopt); binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
return;
}
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
} else { // VectorVector
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
}
} }
template <typename Op> template <typename... Ops>
void binary( void binary(
const array& a, const array& a,
const array& b, const array& b,
std::vector<array>& outputs, std::vector<array>& outputs,
Op op) { Ops... ops) {
switch (outputs[0].dtype()) { switch (outputs[0].dtype()) {
case bool_: case bool_:
binary_op<bool>(a, b, outputs, op); binary_op<bool>(a, b, outputs, ops...);
break; break;
case uint8: case uint8:
binary_op<uint8_t>(a, b, outputs, op); binary_op<uint8_t>(a, b, outputs, ops...);
break; break;
case uint16: case uint16:
binary_op<uint16_t>(a, b, outputs, op); binary_op<uint16_t>(a, b, outputs, ops...);
break; break;
case uint32: case uint32:
binary_op<uint32_t>(a, b, outputs, op); binary_op<uint32_t>(a, b, outputs, ops...);
break; break;
case uint64: case uint64:
binary_op<uint64_t>(a, b, outputs, op); binary_op<uint64_t>(a, b, outputs, ops...);
break; break;
case int8: case int8:
binary_op<int8_t>(a, b, outputs, op); binary_op<int8_t>(a, b, outputs, ops...);
break; break;
case int16: case int16:
binary_op<int16_t>(a, b, outputs, op); binary_op<int16_t>(a, b, outputs, ops...);
break; break;
case int32: case int32:
binary_op<int32_t>(a, b, outputs, op); binary_op<int32_t>(a, b, outputs, ops...);
break; break;
case int64: case int64:
binary_op<int64_t>(a, b, outputs, op); binary_op<int64_t>(a, b, outputs, ops...);
break; break;
case float16: case float16:
binary_op<float16_t>(a, b, outputs, op); binary_op<float16_t>(a, b, outputs, ops...);
break; break;
case float32: case float32:
binary_op<float>(a, b, outputs, op); binary_op<float>(a, b, outputs, ops...);
break; break;
case bfloat16: case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op); binary_op<bfloat16_t>(a, b, outputs, ops...);
break; break;
case complex64: case complex64:
binary_op<complex64_t>(a, b, outputs, op); binary_op<complex64_t>(a, b, outputs, ops...);
break; break;
} }
} }
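
Editor's note: this header mirrors binary.h for primitives that produce two outputs per element pair — DivMod is the typical user — so the element op returns a std::pair and each half is written to its own output array. A minimal sketch of that shape, restricted to contiguous inputs; the names are illustrative.

// Two-output elementwise op: the functor returns a std::pair and each
// component lands in a separate destination buffer.
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

template <typename T, typename Op>
void binary_two(const std::vector<T>& a, const std::vector<T>& b,
                std::vector<T>& out_a, std::vector<T>& out_b, Op op) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    auto r = op(a[i], b[i]);
    out_a[i] = r.first;
    out_b[i] = r.second;
  }
}

int main() {
  std::vector<int> a{7, 9, 13}, b{2, 4, 5}, q(3), r(3);
  // DivMod-style op: quotient and remainder in one pass
  binary_two(a, b, q, r,
             [](int x, int y) { return std::make_pair(x / y, x % y); });
  std::printf("%d %d %d | %d %d %d\n", q[0], q[1], q[2], r[0], r[1], r[2]);
}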

View File

@@ -2,12 +2,46 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core { namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) { void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that // Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric: // the matrix should be symmetric:
@@ -32,14 +66,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization. // Compute Cholesky factorization.
int info; int info = spotrf_wrapper(uplo, matrix, N);
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite // TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how // because throwing an error would result in a crash. If we figure out how
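
Editor's note: spotrf_wrapper exists because some LAPACK ABIs append a hidden string-length argument for the 'uplo' character (the LAPACK_FORTRAN_STRLEN_END branch). The sketch below factors a small SPD matrix by calling spotrf_ directly and omits that hidden argument, which most common ABIs tolerate — exactly the portability detail the wrapper papers over. It assumes a system LAPACK is linked (e.g. -llapack or Accelerate); since the input is symmetric, the same buffer is valid in row- or column-major order and only the meaning of 'uplo' flips.

#include <cstdio>

extern "C" void spotrf_(const char* uplo, const int* n, float* a,
                        const int* lda, int* info);

int main() {
  // Column-major 2x2 SPD matrix [[4, 2], [2, 3]]
  float a[4] = {4.f, 2.f, 2.f, 3.f};
  int n = 2, lda = 2, info = 0;
  char uplo = 'L';
  spotrf_(&uplo, &n, a, &lda, &info);
  // Expect L = [[2, 0], [1, sqrt(2)]] written into the lower triangle
  std::printf("info=%d  L = [[%.3f, 0], [%.3f, %.3f]]\n", info, a[0], a[1], a[3]);
}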

View File

@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(inputs[0]); out.copy_shared_buffer(inputs[0]);
} }
void CustomTransforms::eval( void CustomVJP::eval(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
@@ -156,7 +156,8 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
} }
// Firstly let's collapse all the contiguous dimensions of the input // Firstly let's collapse all the contiguous dimensions of the input
auto [shape, strides] = collapse_contiguous_dims(in); auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so // If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check. // let's check.

View File

@@ -18,8 +18,7 @@ void print_constant(std::ostream& os, const array& x) {
case complex64: case complex64:
return print_complex_constant<complex64_t>(os, x); return print_complex_constant<complex64_t>(os, x);
case int8: case int8:
os << static_cast<int32_t>(x.item<int8_t>()); return print_int_constant<int8_t>(os, x);
return;
case int16: case int16:
return print_int_constant<int16_t>(os, x); return print_int_constant<int16_t>(os, x);
case int32: case int32:
@@ -27,8 +26,7 @@ void print_constant(std::ostream& os, const array& x) {
case int64: case int64:
return print_int_constant<int64_t>(os, x); return print_int_constant<int64_t>(os, x);
case uint8: case uint8:
os << static_cast<uint32_t>(x.item<uint8_t>()); return print_int_constant<uint8_t>(os, x);
return;
case uint16: case uint16:
return print_int_constant<uint16_t>(os, x); return print_int_constant<uint16_t>(os, x);
case uint32: case uint32:
@@ -207,8 +205,8 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Correct size // - Correct size
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) { if (move_buffers) {
outputs[o].move_shared_buffer( outputs[o].move_shared_buffer(
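
Editor's note: the print_constant change for int8 and uint8 is about operator<< semantics — int8_t is typically an alias for signed char, so streaming it emits a character rather than the numeric literal the generated kernel source needs. A two-line demonstration:

#include <cstdint>
#include <iostream>

int main() {
  int8_t x = 65;
  std::cout << x << "\n";                        // prints the character 'A'
  std::cout << static_cast<int32_t>(x) << "\n";  // prints the number 65
}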

View File

@@ -2,10 +2,7 @@
#include <dlfcn.h> #include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include <fstream>
#include <list> #include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h" #include "mlx/backend/common/compiled_preamble.h"
@@ -14,30 +11,6 @@
namespace mlx::core { namespace mlx::core {
struct CompilerCache {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
std::list<DLib> libs;
std::unordered_map<std::string, void*> kernels;
std::shared_mutex mtx;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available and since we are in // GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available. // this file CPU compile is also available.
namespace detail { namespace detail {
@@ -53,19 +26,32 @@ std::string get_temp_file(const std::string& name) {
// Return a pointer to a compiled function // Return a pointer to a compiled function
void* compile( void* compile(
const std::string& kernel_name, const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) { const std::string& source_code = "") {
{ struct DLib {
std::shared_lock lock(cache.mtx); DLib(const std::string& libname) {
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { lib = dlopen(libname.c_str(), RTLD_NOW);
return it->second; if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
} }
}
std::unique_lock lock(cache.mtx); ~DLib() {
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second; return it->second;
} }
std::string source_code = source_builder(); if (source_code.empty()) {
return nullptr;
}
std::string kernel_file_name; std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255 // Deal with long kernel names. Maximum length for files on macOS is 255
@@ -103,8 +89,8 @@ void* compile(
source_file.close(); source_file.close();
std::ostringstream build_command; std::ostringstream build_command;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '" build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << "' -o '" << shared_lib_path << "'"; << source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str(); std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str()); auto return_code = system(build_command_str.c_str());
if (return_code) { if (return_code) {
@@ -116,10 +102,10 @@ void* compile(
} }
// load library // load library
cache.libs.emplace_back(shared_lib_path); libs.emplace_back(shared_lib_path);
// Load function // Load function
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); void* fun = dlsym(libs.back().lib, kernel_name.c_str());
if (!fun) { if (!fun) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function " msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -127,7 +113,7 @@ void* compile(
<< dlerror(); << dlerror();
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
cache.kernels.insert({kernel_name, fun}); kernels.insert({kernel_name, fun});
return fun; return fun;
} }
@@ -329,7 +315,10 @@ void Compiled::eval_cpu(
} }
// Get the function // Get the function
auto fn_ptr = compile(kernel_name, [&]() { auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel; std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl; kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl; kernel << "extern \"C\" {" << std::endl;
@@ -344,8 +333,10 @@ void Compiled::eval_cpu(
ndim); ndim);
// Close extern "C" // Close extern "C"
kernel << "}" << std::endl; kernel << "}" << std::endl;
return kernel.str();
}); // Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
compiled_allocate_outputs( compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false); inputs, outputs, inputs_, constant_ids_, contiguous, false);

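The hunk above moves the JIT cache behind a shared_mutex and a source_builder callback. A hedged sketch of the double-checked lookup pattern it uses (names are illustrative, not MLX's API): readers take a shared lock, and only a thread that misses takes the exclusive lock, re-checks, and builds.

#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct KernelCache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

void* lookup_or_build(
    KernelCache& cache,
    const std::string& name,
    const std::function<void*()>& build) {
  {
    std::shared_lock lock(cache.mtx);
    if (auto it = cache.kernels.find(name); it != cache.kernels.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(cache.mtx);
  // Re-check under the exclusive lock: another thread may have built the
  // kernel while we were waiting.
  if (auto it = cache.kernels.find(name); it != cache.kernels.end()) {
    return it->second;
  }
  void* fn = build();  // e.g. compile, dlopen, and dlsym the kernel
  cache.kernels.insert({name, fn});
  return fn;
}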
View File

@@ -3,8 +3,13 @@
#include <cassert> #include <cassert>
#include <numeric> #include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -679,32 +684,6 @@ void dispatch_slow_conv_3D(
// Explicit gemm conv // Explicit gemm conv
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
template <typename T>
void flip_spatial_dims_inplace(array& wt) {
T* x = wt.data<T>();
size_t out_channels = wt.shape(0);
size_t in_channels = wt.shape(-1);
// Calculate the total size of the spatial dimensions
int spatial_size = 1;
for (int d = 1; d < wt.ndim() - 1; ++d) {
spatial_size *= wt.shape(d);
}
for (size_t i = 0; i < out_channels; i++) {
T* top = x + i * spatial_size * in_channels;
T* bottom =
x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
for (size_t j = 0; j < spatial_size / 2; j++) {
for (size_t k = 0; k < in_channels; k++) {
std::swap(top[k], bottom[k]);
}
top += in_channels;
bottom -= in_channels;
}
}
}
void explicit_gemm_conv_1D_cpu( void explicit_gemm_conv_1D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
@@ -931,8 +910,7 @@ void explicit_gemm_conv_ND_cpu(
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation) {
const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>( const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
@@ -1022,14 +1000,6 @@ void explicit_gemm_conv_ND_cpu(
copy(wt, gemm_wt, ctype); copy(wt, gemm_wt, ctype);
} }
if (flip) {
auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
copy(gemm_wt, gemm_wt_, CopyType::Vector);
flip_spatial_dims_inplace<float>(gemm_wt_);
gemm_wt = gemm_wt_;
}
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
@@ -1072,15 +1042,10 @@ void conv_1D_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip) { bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu( return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation); in, wt, out, padding, wt_strides, wt_dilation);
} }
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_1D( return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
@@ -1095,13 +1060,6 @@ void conv_2D_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip) { bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_2D( return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} }
@@ -1115,14 +1073,6 @@ void conv_3D_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
bool flip) { bool flip) {
const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) {
return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip);
}
return dispatch_slow_conv_3D( return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} }
@@ -1175,7 +1125,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
else { else {
std::ostringstream msg; std::ostringstream msg;
msg << "[Convolution::eval] Convolution currently only supports" msg << "[Convolution::eval] Convolution currently only supports"
<< " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2 << " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
<< " spatial dimensions"; << " spatial dimensions";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }

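flip_spatial_dims_inplace above reverses the spatial positions of a weight tensor laid out as (out_channels, spatial..., in_channels) so the explicit-GEMM path can serve flipped (true convolution) weights. A standalone sketch over a flat buffer, with the layout assumed rather than taken from an MLX array:

#include <algorithm>
#include <cstddef>
#include <vector>

void flip_spatial(std::vector<float>& w, size_t out_c, size_t spatial, size_t in_c) {
  for (size_t o = 0; o < out_c; ++o) {
    float* base = w.data() + o * spatial * in_c;
    for (size_t s = 0; s < spatial / 2; ++s) {
      // Swap the in_channel vectors at mirrored spatial positions.
      float* top = base + s * in_c;
      float* bottom = base + (spatial - 1 - s) * in_c;
      std::swap_ranges(top, top + in_c, bottom);
    }
  }
}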
View File

@@ -4,7 +4,6 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -26,117 +25,252 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
} }
template <typename SrcT, typename DstT, typename StrideT, int D> template <typename SrcT, typename DstT, typename stride_t>
inline void copy_dims( void copy_general_dim1(
const SrcT* src, const array& src,
DstT* dst, array& dst,
const std::vector<int>& shape, const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides, const std::vector<stride_t>& i_strides,
const std::vector<StrideT>& o_strides, int64_t i_offset) {
int axis) { const SrcT* src_ptr = src.data<SrcT>();
auto stride_src = i_strides[axis]; DstT* dst_ptr = dst.data<DstT>();
auto stride_dst = o_strides[axis]; stride_t src_idx = i_offset;
auto N = shape[axis]; stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int i = 0; i < N; i++) { dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
if constexpr (D > 1) { src_idx += i_strides[0];
copy_dims<SrcT, DstT, StrideT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1);
} else {
*dst = static_cast<DstT>(*src);
}
src += stride_src;
dst += stride_dst;
} }
} }
template <typename SrcT, typename DstT, typename StrideT> template <typename SrcT, typename DstT>
inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
dst_ptr += stride_dst;
}
}
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides, const std::vector<stride_t>& i_strides,
const std::vector<StrideT>& o_strides, const std::vector<stride_t>& o_strides,
int64_t i_offset, stride_t i_offset,
int64_t o_offset) { stride_t o_offset) {
if (data_shape.empty()) { switch (src.ndim()) {
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset)); case 1:
auto dst_ptr = dst.data<DstT>() + o_offset; copy_general_general_dims<SrcT, DstT, stride_t, 1>(
*dst_ptr = val; src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return; return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
} }
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides}); int size = std::accumulate(
auto src_ptr = src.data<SrcT>() + i_offset; data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
auto dst_ptr = dst.data<DstT>() + o_offset; for (int i = 0; i < src.size(); i += size) {
int ndim = shape.size(); stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
if (ndim == 1) { stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_dims<SrcT, DstT, StrideT, 1>( copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, StrideT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) { inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>( return copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
} }
template <typename SrcT, typename DstT, typename StrideT>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides<StrideT>(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides<size_t>(src.shape()),
0,
0);
}
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) { void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
switch (ctype) { switch (ctype) {
@@ -151,7 +285,6 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
return; return;
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...); copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
return;
} }
} }
@@ -252,7 +385,7 @@ inline void copy_inplace_dispatch(
} // namespace } // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) { void copy_inplace(const array& src, array& dst, CopyType ctype) {
copy_inplace_dispatch(src, dst, ctype); return copy_inplace_dispatch(src, dst, ctype);
} }
void copy(const array& src, array& dst, CopyType ctype) { void copy(const array& src, array& dst, CopyType ctype) {
@@ -282,20 +415,20 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype); copy_inplace(src, dst, ctype);
} }
template <typename StrideT> template <typename stride_t>
void copy_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides, const std::vector<stride_t>& i_strides,
const std::vector<StrideT>& o_strides, const std::vector<stride_t>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype) {
switch (ctype) { switch (ctype) {
case CopyType::General: case CopyType::General:
case CopyType::GeneralGeneral: case CopyType::GeneralGeneral:
copy_inplace_dispatch( return copy_inplace_dispatch(
src, src,
dst, dst,
ctype, ctype,
@@ -304,24 +437,15 @@ void copy_inplace(
o_strides, o_strides,
i_offset, i_offset,
o_offset); o_offset);
break;
case CopyType::Scalar: case CopyType::Scalar:
case CopyType::Vector: case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype); return copy_inplace_dispatch(src, dst, ctype);
} }
} }
template void copy_inplace<size_t>( template <>
const array& src, void copy_inplace<int64_t>(
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
@@ -329,6 +453,24 @@ template void copy_inplace<int64_t>(
const std::vector<int64_t>& o_strides, const std::vector<int64_t>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype); CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core } // namespace mlx::core

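Both versions of the copy code lean on elem_to_loc to map a flat element index onto a strided location. A sketch of that helper (signature assumed to match the common-backend utility):

#include <cstddef>
#include <vector>

size_t elem_to_loc(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  // Walk from the innermost dimension outward, peeling off one coordinate at
  // a time and accumulating coordinate * stride.
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}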
View File

@@ -1,10 +1,14 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring> #include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -48,7 +52,7 @@ DEFAULT(Convolution)
DEFAULT(Copy) DEFAULT(Copy)
DEFAULT(Cos) DEFAULT(Cos)
DEFAULT(Cosh) DEFAULT(Cosh)
DEFAULT_MULTI(CustomTransforms) DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends) DEFAULT_MULTI(Depends)
DEFAULT(Divide) DEFAULT(Divide)
DEFAULT(NumberOfElements) DEFAULT(NumberOfElements)
@@ -64,7 +68,6 @@ DEFAULT(Full)
DEFAULT(Gather) DEFAULT(Gather)
DEFAULT(Greater) DEFAULT(Greater)
DEFAULT(GreaterEqual) DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less) DEFAULT(Less)
DEFAULT(LessEqual) DEFAULT(LessEqual)
DEFAULT(Load) DEFAULT(Load)
@@ -110,7 +113,6 @@ DEFAULT(Tanh)
DEFAULT(Transpose) DEFAULT(Transpose)
DEFAULT(Inverse) DEFAULT(Inverse)
DEFAULT(Cholesky) DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace { namespace {

View File

@@ -1,117 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core

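The Eigh implementation above uses the standard LAPACK workspace-query idiom: a first ssyevd call with lwork = -1 only reports the optimal work/iwork sizes, and the real decomposition runs after those buffers are allocated. A hedged, standalone sketch (the extern declaration mirrors the usual Fortran symbol; error handling is elided):

#include <vector>

extern "C" void ssyevd_(
    char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
    float* work, int* lwork, int* iwork, int* liwork, int* info);

void eigh_inplace(float* a, float* w, int n, bool vectors) {
  char jobz = vectors ? 'V' : 'N';
  char uplo = 'L';
  int info = 0;

  // Workspace query: lwork = liwork = -1 only writes the required sizes.
  float work_size = 0;
  int iwork_size = 0;
  int lwork = -1, liwork = -1;
  ssyevd_(&jobz, &uplo, &n, a, &n, w, &work_size, &lwork, &iwork_size, &liwork, &info);

  lwork = static_cast<int>(work_size);
  liwork = iwork_size;
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_(&jobz, &uplo, &n, a, &n, w, work.data(), &lwork, iwork.data(), &liwork, &info);
}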
View File

@@ -1,107 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core

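hadamard_n above is an in-place fast Walsh-Hadamard butterfly over the power-of-two part of the length. A minimal standalone version of the same transform (the scaling that the kernel folds into the last stage is omitted here):

#include <vector>

void fwht(std::vector<float>& x) {
  int n = static_cast<int>(x.size());  // assumed to be a power of two
  for (int h = 1; h < n; h <<= 1) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float a = x[j];
        float b = x[j + h];
        x[j] = a + b;
        x[j + h] = a - b;
      }
    }
  }
}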
View File

@@ -1,105 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
} // namespace mlx::core

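decompose_hadamard above factors the transform length as n = m * 2^k with m one of the hard-coded base matrices. A hedged sketch of the same rule; is_power_of_2 is the usual bit trick standing in for the MLX helper, and this sketch additionally re-checks that the cofactor is a power of two:

#include <stdexcept>
#include <utility>

inline bool is_power_of_2(int n) {
  return n > 0 && (n & (n - 1)) == 0;
}

std::pair<int, int> decompose(int n) {
  if (is_power_of_2(n)) {
    return {n, 1};  // pure 2^k case
  }
  for (int m : {12, 20, 28}) {
    if (n % m == 0 && is_power_of_2(n / m)) {
      return {n / m, m};
    }
  }
  throw std::invalid_argument(
      "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}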
View File

@@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@@ -80,18 +81,11 @@ void gather(
T* dst_ptr = out.data<T>(); T* dst_ptr = out.data<T>();
size_t out_idx = 0; size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) { for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii]; auto ax = axes[ii];
auto idx_loc = its[ii].loc; auto idx_loc = elem_to_loc(idx, inds[ii]);
its[ii].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax)); offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]); src_idx += (idx_val * src.strides()[ax]);
@@ -105,10 +99,9 @@ void gather(
out_idx += slice_size; out_idx += slice_size;
} else { } else {
for (int jj = 0; jj < slice_size; jj++) { for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
src_it.step(); dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
} }
src_it.reset();
} }
} }
} }
@@ -230,29 +223,21 @@ void scatter(
update_size *= us; update_size *= us;
} }
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;
for (int j = 0; j < nind; ++j) { for (int j = 0; j < nind; ++j) {
auto ax = axes[j]; auto ax = axes[j];
auto idx_loc = its[j].loc; auto idx_loc = elem_to_loc(i, inds[j]);
its[j].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax)); offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]); out_offset += (idx_val * out.strides()[ax]);
} }
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) { for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc], auto update_loc = elem_to_loc(i * update_size + j, updates);
out.data<InT>() + out_offset + out_it.loc); auto out_loc = elem_to_loc(j, update_shape, out.strides());
update_it.step(); op(updates.data<InT>()[update_loc],
out_it.step(); out.data<InT>() + out_offset + out_loc);
} }
out_it.reset();
update_it.reset();
} }
} }

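The gather/scatter hunks trade repeated elem_to_loc calls for ContiguousIterator objects that advance a strided location one element at a time. A simplified sketch of such an iterator (not MLX's exact class):

#include <cstddef>
#include <utility>
#include <vector>

struct StridedLoc {
  std::vector<int> shape;
  std::vector<size_t> strides;
  std::vector<int> pos;
  size_t loc = 0;

  StridedLoc(std::vector<int> s, std::vector<size_t> st)
      : shape(std::move(s)), strides(std::move(st)), pos(shape.size(), 0) {}

  void step() {
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
      loc += strides[i];
      if (++pos[i] < shape[i]) {
        return;
      }
      // Carry: rewind this axis and advance the next outer one.
      loc -= static_cast<size_t>(shape[i]) * strides[i];
      pos[i] = 0;
    }
  }
};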
View File

@@ -2,94 +2,17 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
int strtri_wrapper(char uplo, char diag, float* matrix, int N) { #ifdef ACCELERATE_NEW_LAPACK
int info; #include <Accelerate/Accelerate.h>
MLX_LAPACK_FUNC(strtri) #else
( #include <lapack.h>
/* uplo = */ &uplo, #endif
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
return info;
}
namespace mlx::core { namespace mlx::core {
void general_inv(array& inv, int N, int i) { void inverse_impl(const array& a, array& inv) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following // Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see // identity to avoid transposing (see
// https://math.stackexchange.com/a/340234): // https://math.stackexchange.com/a/340234):
@@ -101,11 +24,63 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
const int N = a.shape(-1); const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N); const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
if (tri) { // Compute LU factorization.
tri_inv(inv, N, i, upper); sgetrf_(
} else { /* m = */ &N,
general_inv(inv, N, i); /* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
} }
} }
} }
@@ -114,7 +89,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) { if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32."); throw std::runtime_error("[Inverse::eval] only supports float32.");
} }
inverse_impl(inputs[0], output, tri_, upper_); inverse_impl(inputs[0], output);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,11 +1,10 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2024 Apple Inc.
#pragma once #pragma once
#ifdef ACCELERATE_NEW_LAPACK #ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#else #else
#include <cblas.h>
#include <lapack.h> #include <lapack.h>
#endif #endif

View File

@@ -5,9 +5,11 @@
#include <utility> #include <utility>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/load.h" #include "mlx/io/load.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core {
namespace { namespace {
template <const uint8_t scalar_size> template <const uint8_t scalar_size>
@@ -27,14 +29,12 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
} // namespace } // namespace
namespace mlx::core { void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
void load( reader_->seek(offset_, std::ios_base::beg);
array& out, reader_->read(out.data<char>(), out.nbytes());
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness_) {
reader->read(out.data<char>(), out.nbytes(), offset);
if (swap_endianness_) { if (swap_endianness_) {
switch (out.itemsize()) { switch (out.itemsize()) {
@@ -51,11 +51,4 @@ void load(
} }
} }
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
} // namespace mlx::core } // namespace mlx::core

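The load path above reads raw bytes at an offset and, when the file's byte order differs from the host's, swaps each scalar in place. A sketch of that swap, with the template parameter mirroring the scalar_size (itemsize) in the hunk:

#include <algorithm>
#include <cstddef>
#include <cstdint>

template <size_t scalar_size>
void swap_endianness(uint8_t* data, size_t n_scalars) {
  for (size_t i = 0; i < n_scalars; ++i) {
    // Reverse the bytes of the i-th scalar in place.
    std::reverse(data + i * scalar_size, data + (i + 1) * scalar_size);
  }
}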
View File

@@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianess);
} // namespace mlx::core

View File

@@ -18,12 +18,10 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
EOM EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi fi
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null) CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
cat << EOF > "$OUTPUT_FILE" cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() { const char* get_kernel_preamble() {

View File

@@ -1,10 +1,15 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring> #include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"

View File

@@ -295,13 +295,6 @@ struct Floor {
} }
}; };
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log { struct Log {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
@@ -344,13 +337,6 @@ struct Negative {
} }
}; };
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round { struct Round {
template <typename T> template <typename T>
T operator()(T x) { T operator()(T x) {
@@ -387,10 +373,6 @@ struct Sign {
uint64_t operator()(uint64_t x) { uint64_t operator()(uint64_t x) {
return x != 0; return x != 0;
} }
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
}; };
struct Sin { struct Sin {

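The complex branch of Sign handled above computes sign(z) = z / |z| with sign(0) = 0, i.e. a unit-magnitude value with the same phase. A quick standalone check with std::complex:

#include <complex>
#include <iostream>

int main() {
  std::complex<float> z(3.0f, 4.0f);
  auto sign = (z == std::complex<float>(0)) ? z : z / std::abs(z);
  std::cout << sign << "\n";  // (0.6,0.8), magnitude 1
  return 0;
}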
View File

@@ -273,10 +273,6 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype); copy(in, out, ctype);
} }
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) { void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -402,10 +398,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) { void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -413,8 +405,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
copy_inplace(in, out, CopyType::General);
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
} }
@@ -504,16 +495,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
/* int64_t o_offset = */ 0, /* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General); /* CopyType ctype = */ CopyType::General);
} else { } else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()}; std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out); shared_buffer_slice(in, ostrides, data_offset, out);
} }
} }
@@ -611,18 +594,11 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
strides[i] /= obytes; strides[i] /= obytes;
} }
out.copy_shared_buffer( out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes); in, strides, in.flags(), in.data_size() * obytes / ibytes);
} else { } else {
auto tmp = array( auto tmp = array(in.shape(), in.dtype(), nullptr, {});
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) { copy_inplace(in, tmp, CopyType::General);
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags(); auto flags = out.flags();
flags.contiguous = true; flags.contiguous = true;

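The View hunk above scales both the strides and the shared data_size by the ratio of the input and output itemsizes. A quick check of that arithmetic (the values here are illustrative):

#include <cstddef>
#include <iostream>

int main() {
  size_t data_size = 8;           // eight float32 elements
  size_t ibytes = 4, obytes = 2;  // viewing float32 as int16
  std::cout << data_size * ibytes / obytes << "\n";  // 16 int16 elements
  return 0;
}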
View File

@@ -2,9 +2,14 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core { namespace mlx::core {
template <typename T> template <typename T>

View File

@@ -201,61 +201,55 @@ void _qmm_dispatch(
int group_size, int group_size,
bool transposed_w) { bool transposed_w) {
int K = x.shape(-1); int K = x.shape(-1);
int M = x.shape(-2); int M = x.size() / K;
int N = out.shape(-1); int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; switch (x.dtype()) {
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; case float32:
_qmm_dispatch_typed<float>(
int batch_size = x.size() / x.shape(-1) / x.shape(-2); out.data<float>(),
for (int i = 0; i < batch_size; i++) { x.data<float>(),
switch (x.dtype()) { w.data<uint32_t>(),
case float32: scales.data<float>(),
_qmm_dispatch_typed<float>( biases.data<float>(),
out.data<float>() + i * M * N, M,
x.data<float>() + elem_to_loc(i * M * K, x), N,
w.data<uint32_t>() + elem_to_loc(i * w_els, w), K,
scales.data<float>() + elem_to_loc(i * g_els, scales), bits,
biases.data<float>() + elem_to_loc(i * g_els, biases), group_size,
M, transposed_w);
N, break;
K, case float16:
bits, _qmm_dispatch_typed<float16_t>(
group_size, out.data<float16_t>(),
transposed_w); x.data<float16_t>(),
break; w.data<uint32_t>(),
case float16: scales.data<float16_t>(),
_qmm_dispatch_typed<float16_t>( biases.data<float16_t>(),
out.data<float16_t>() + i * M * N, M,
x.data<float16_t>() + elem_to_loc(i * M * K, x), N,
w.data<uint32_t>() + elem_to_loc(i * w_els, w), K,
scales.data<float16_t>() + elem_to_loc(i * g_els, scales), bits,
biases.data<float16_t>() + elem_to_loc(i * g_els, biases), group_size,
M, transposed_w);
N, break;
K, case bfloat16:
bits, _qmm_dispatch_typed<bfloat16_t>(
group_size, out.data<bfloat16_t>(),
transposed_w); x.data<bfloat16_t>(),
break; w.data<uint32_t>(),
case bfloat16: scales.data<bfloat16_t>(),
_qmm_dispatch_typed<bfloat16_t>( biases.data<bfloat16_t>(),
out.data<bfloat16_t>() + i * M * N, M,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x), N,
w.data<uint32_t>() + elem_to_loc(i * w_els, w), K,
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales), bits,
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases), group_size,
M, transposed_w);
N, break;
K, default:
bits, throw std::invalid_argument(
group_size, "[quantized_matmul] only floating types are supported");
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
} }
} }

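The batching added to _qmm_dispatch above loops over the leading batch dimensions and offsets each operand by its own (possibly broadcast) strides before running the 2-D kernel. A hedged sketch of the same pattern, with a plain float matmul standing in for the quantized kernel and a zero stride modelling a broadcast operand:

#include <cstddef>
#include <vector>

void matmul(float* out, const float* x, const float* w, int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += x[m * K + k] * w[k * N + n];
      }
      out[m * N + n] = acc;
    }
  }
}

void batched_matmul(
    float* out, const float* x, const float* w,
    int batch, int M, int N, int K,
    size_t x_batch_stride, size_t w_batch_stride) {
  for (int b = 0; b < batch; ++b) {
    // Each operand advances by its own batch stride; a broadcast weight
    // simply uses w_batch_stride == 0.
    matmul(out + static_cast<size_t>(b) * M * N,
           x + b * x_batch_stride,
           w + b * w_batch_stride,
           M, N, K);
  }
}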
View File

@@ -87,38 +87,6 @@ struct OrReduce {
} }
}; };
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
};
};
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
};
};
template <typename InT> template <typename InT>
void reduce_dispatch_out( void reduce_dispatch_out(
const array& in, const array& in,
@@ -150,13 +118,15 @@ void reduce_dispatch_out(
break; break;
} }
case Reduce::Max: { case Reduce::Max: {
auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
auto init = Limits<InT>::min; auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce()); reduction_op<InT, InT>(in, out, axes, init, op);
break; break;
} }
case Reduce::Min: { case Reduce::Min: {
auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
auto init = Limits<InT>::max; auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce()); reduction_op<InT, InT>(in, out, axes, init, op);
break; break;
} }
} }

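The Max/Min functors above special-case NaN because every ordered comparison involving NaN is false, so a bare (*y > x) update cannot be relied on to notice one. A quick standalone check:

#include <cmath>
#include <iostream>

int main() {
  float nan = std::nanf("");
  std::cout << std::boolalpha
            << (nan > 3.0f) << " "    // false
            << (3.0f > nan) << " "    // false
            << (nan == nan) << "\n";  // false
  return 0;
}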
View File

@@ -49,7 +49,7 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {} ReductionPlan(ReductionOpType type_) : type(type_) {}
}; };
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes); ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
// Helper for the ndimensional strided loop // Helper for the ndimensional strided loop
// Should this be in utils? // Should this be in utils?

View File

@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
return std::make_pair(shape, strides); return std::make_pair(shape, strides);
} }
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) { ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything // The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() && if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) { x.flags().contiguous) {
@@ -32,7 +32,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
std::vector<int> shape = {x.shape(axes[0])}; std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]}; std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) { for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { if (axes[i] - 1 == axes[i - 1]) {
shape.back() *= x.shape(axes[i]); shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]]; strides.back() = x.strides()[axes[i]];
} else { } else {
@@ -41,14 +41,6 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
} }
} }
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) { if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides); return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) { } else if (strides.back() > 1) {
@@ -71,14 +63,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// have a contiguous reduction. // have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions; std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) { for (auto a : axes) {
if (x.shape(a) > 1) { reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
} }
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
bool a_is_zero = a.second == 0; return a.second > b.second;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
}); });
// Extract the two smallest and try to merge them in case the contiguous // Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis. // reduction can be bigger than just the last axis.
@@ -110,33 +98,16 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// strides.back() are contiguous. // strides.back() are contiguous.
if (strides.back() > 1) { if (strides.back() > 1) {
int size = 1; int size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) { for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) { if (axes.back() == i) {
continue; continue;
} }
if (x.strides()[i] != size) {
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break; break;
} }
size *= x.shape(i);
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
} }
// In the case of an expanded dimension we are being conservative and if (size >= strides.back()) {
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides); return ReductionPlan(GeneralStridedReduce, shape, strides);
} }
} }

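get_reduction_plan above folds consecutive reduction axes into a single (size, stride) pair, keeping the innermost stride. A standalone sketch of that merging step, taking plain shape/stride vectors instead of an MLX array:

#include <cstddef>
#include <utility>
#include <vector>

std::vector<std::pair<int, size_t>> merge_reduction_axes(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides,
    const std::vector<int>& axes) {
  std::vector<std::pair<int, size_t>> merged = {
      {shape[axes[0]], strides[axes[0]]}};
  for (size_t i = 1; i < axes.size(); ++i) {
    if (axes[i] - 1 == axes[i - 1]) {
      // Adjacent axes collapse: multiply the sizes, keep the inner stride.
      merged.back().first *= shape[axes[i]];
      merged.back().second = strides[axes[i]];
    } else {
      merged.emplace_back(shape[axes[i]], strides[axes[i]]);
    }
  }
  return merged;
}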
View File

@@ -6,16 +6,18 @@ namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice( std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in, const array& in,
const std::vector<int>& start_indices, std::vector<int>& start_indices,
const std::vector<int>& strides) { std::vector<int>& strides) {
int64_t data_offset = 0; int64_t data_offset = 0;
bool copy_needed = false; bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0); std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) { for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i]; data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i]; inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0; copy_needed |= strides[i] < 0;
} }
return std::make_tuple(copy_needed, data_offset, inp_strides); return std::make_tuple(copy_needed, data_offset, inp_strides);
} }
@@ -23,16 +25,26 @@ void shared_buffer_slice(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const std::vector<size_t>& out_strides,
size_t data_offset, size_t data_offset,
size_t data_size,
array& out) { array& out) {
// Compute row/col contiguity // Compute row/col contiguity
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides); check_contiguity(out.shape(), out_strides);
auto flags = in.flags(); auto flags = in.flags();
flags.row_contiguous = is_row_contiguous; flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous; flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size);
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
} }

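prepare_slice above turns a slice into a strided view: the element offset is the dot product of the start indices with the input strides, each output stride is the input stride scaled by the slice step, and a negative step forces a copy. A standalone sketch of the same bookkeeping:

#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

std::tuple<bool, int64_t, std::vector<int64_t>> slice_view(
    const std::vector<size_t>& in_strides,
    const std::vector<int>& starts,
    const std::vector<int>& steps) {
  int64_t offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> out_strides(in_strides.size());
  for (size_t i = 0; i < in_strides.size(); ++i) {
    offset += static_cast<int64_t>(starts[i]) * static_cast<int64_t>(in_strides[i]);
    out_strides[i] = static_cast<int64_t>(in_strides[i]) * steps[i];
    copy_needed |= steps[i] < 0;
  }
  return {copy_needed, offset, out_strides};
}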
View File

@@ -8,14 +8,13 @@ namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice( std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in, const array& in,
const std::vector<int>& start_indices, std::vector<int>& start_indices,
const std::vector<int>& strides); std::vector<int>& strides);
void shared_buffer_slice( void shared_buffer_slice(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const std::vector<size_t>& out_strides,
size_t data_offset, size_t data_offset,
size_t data_size,
array& out); array& out);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -111,29 +111,26 @@ void sort(const array& in, array& out, int axis) {
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t n_rows = in.size() / in.shape(axis);
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = out.shape(); auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = out.strides(); auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = out.strides()[axis]; size_t axis_stride = in.strides()[axis];
int axis_size = out.shape(axis); int axis_size = in.shape(axis);
// Perform sorting in place // Perform sorting in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc; size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed); std::stable_sort(st, ed);
src_it.step();
} }
} }
@@ -146,46 +143,34 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape(); auto remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_strides = in.strides(); auto remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
auto out_remaining_shape = out.shape(); size_t axis_stride = in.strides()[axis];
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = in.shape(axis);
// Perform sorting // Perform sorting
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc; size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; const T* data_ptr = in.data<T>() + loc;
in_it.step(); IdxT* idx_ptr = out.data<IdxT>() + loc;
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size); StridedIterator ed_(idx_ptr, axis_stride, axis_size);
// Initialize with iota // Initialize with iota
std::iota(st_, ed_, IdxT(0)); std::iota(st_, ed_, IdxT(0));
// Sort according to vals // Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0); StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size); StridedIterator ed(idx_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * axis_stride];
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -199,8 +184,7 @@ void partition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info // Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t n_rows = in.size() / in.shape(axis);
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = in.shape(); auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
@@ -214,11 +198,9 @@ void partition(const array& in, array& out, int axis, int kth) {
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place // Perform partition in place
ContiguousIterator<size_t> src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc; size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
src_it.step(); T* data_ptr = out.data<T>() + loc;
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth); StridedIterator md(data_ptr, axis_stride, kth);
@@ -237,49 +219,37 @@ void argpartition(const array& in, array& out, int axis, int kth) {
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape(); auto remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_strides = in.strides(); auto remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
auto out_remaining_shape = out.shape(); size_t axis_stride = in.strides()[axis];
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition // Perform partition
ContiguousIterator<size_t> in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc; size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
IdxT* idx_ptr = out.data<IdxT>() + out_it.loc; const T* data_ptr = in.data<T>() + loc;
in_it.step(); IdxT* idx_ptr = out.data<IdxT>() + loc;
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size); StridedIterator ed_(idx_ptr, axis_stride, axis_size);
// Initialize with iota // Initialize with iota
std::iota(st_, ed_, IdxT(0)); std::iota(st_, ed_, IdxT(0));
// Sort according to vals // Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0); StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth); StridedIterator md(idx_ptr, axis_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size); StridedIterator ed(idx_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * axis_stride];
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }

View File

@@ -2,7 +2,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h" #include "mlx/backend/common/lapack_helper.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {

View File

@@ -12,7 +12,6 @@ namespace {
// TODO: Add support for more combinations of input types. // TODO: Add support for more combinations of input types.
enum class TernaryOpType { enum class TernaryOpType {
ScalarScalarScalar, ScalarScalarScalar,
VectorVectorVector,
General, General,
}; };
@@ -21,12 +20,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt; TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) { if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
topt = TernaryOpType::ScalarScalarScalar; topt = TernaryOpType::ScalarScalarScalar;
} else if (
(a.flags().row_contiguous && b.flags().row_contiguous &&
c.flags().row_contiguous) ||
(a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector;
} else { } else {
topt = TernaryOpType::General; topt = TernaryOpType::General;
} }
@@ -40,77 +33,138 @@ void set_ternary_op_output_data(
array& out, array& out,
TernaryOpType topt, TernaryOpType topt,
bool donate_with_move = false) { bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) {
if (is_donatable(x, out)) {
if (donate_with_move) {
out.move_shared_buffer(x);
} else {
out.copy_shared_buffer(x);
}
return true;
}
return false;
};
switch (topt) { switch (topt) {
case TernaryOpType::ScalarScalarScalar: case TernaryOpType::ScalarScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags()); allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break; break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
}
break;
case TernaryOpType::General: case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
break; break;
} }
} }
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) { template <typename T1, typename T2, typename T3, typename U, typename Op>
if constexpr (D > 1) { void ternary_op_dims1(
ternary_op_dims<T1, T2, T3, U, Op, D - 1>( const array& a,
a, const array& b,
b, const array& c,
c, array& out,
out, Op op) {
op, const T1* a_ptr = a.data<T1>();
shape, const T2* b_ptr = b.data<T2>();
a_strides, const T3* c_ptr = c.data<T3>();
b_strides,
c_strides, U* dst = out.data<U>();
out_strides, size_t a_idx = 0;
axis + 1); size_t b_idx = 0;
} else { size_t c_idx = 0;
*out = op(*a, *b, *c); for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
} }
a += stride_a; a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b += stride_b; b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c += stride_c; c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
out += stride_out; }
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
} }
} }
@@ -121,69 +175,30 @@ void ternary_op_dispatch_dims(
const array& c, const array& c,
array& out, array& out,
Op op) { Op op) {
auto [shape, strides] = collapse_contiguous_dims( switch (out.ndim()) {
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); case 1:
const auto& a_strides = strides[0]; ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
const auto& b_strides = strides[1]; return;
const auto& c_strides = strides[2]; case 2:
const auto& out_strides = strides[3]; ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
const T1* a_ptr = a.data<T1>(); const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>(); const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>(); const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<T3>(); U* dst = out.data<U>();
int ndim = shape.size(); for (size_t i = 0; i < out.size(); i++) {
switch (ndim) { int a_idx = elem_to_loc(i, a.shape(), a.strides());
case 1: int b_idx = elem_to_loc(i, b.shape(), b.strides());
ternary_op_dims<T1, T2, T3, U, Op, 1>( int c_idx = elem_to_loc(i, c.shape(), c.strides());
a_ptr, dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
} }
} }
@@ -200,21 +215,10 @@ void ternary_op(
// The full computation is scalar-scalar-scalar so we call the base op once. // The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) { if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>()); *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
} else if (topt == TernaryOpType::VectorVectorVector) { return;
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
} }
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
} }
} // namespace } // namespace
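
The ternary rewrite swaps the fixed-rank loops (ternary_op_dims1..4) for a recursive template driven by collapsed shapes and strides, and adds back a VectorVectorVector fast path with output donation. Both versions still need a general strided fallback; the standalone sketch below shows that case for a where/select-style op, with a local elem_to_loc helper (illustrative names, not MLX's API):

// Sketch: apply op(a, b, c) elementwise when a, b, c share a logical shape
// but carry independent strides (e.g. broadcasts or transposes).
#include <cstdio>
#include <vector>

// Map a flat row-major element index to a storage location for given strides.
size_t elem_to_loc(size_t elem, const std::vector<int>& shape,
                   const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

template <typename Op>
void ternary_general(const std::vector<int>& shape,
                     const bool* a, const std::vector<size_t>& as,
                     const float* b, const std::vector<size_t>& bs,
                     const float* c, const std::vector<size_t>& cs,
                     float* out, Op op) {
  size_t n = 1;
  for (int d : shape) n *= d;
  for (size_t i = 0; i < n; ++i) {
    out[i] = op(a[elem_to_loc(i, shape, as)],
                b[elem_to_loc(i, shape, bs)],
                c[elem_to_loc(i, shape, cs)]);
  }
}

int main() {
  // where(cond, x, y) on a 2x2 problem; cond is broadcast along the rows.
  bool cond[] = {true, false};
  float x[] = {1, 2, 3, 4}, y[] = {10, 20, 30, 40}, out[4];
  ternary_general({2, 2}, cond, {0, 1}, x, {2, 1}, y, {2, 1}, out,
                  [](bool c, float a, float b) { return c ? a : b; });
  for (float v : out) std::printf("%g ", v);  // 1 20 3 40
  std::printf("\n");
}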

View File

@@ -12,7 +12,7 @@ namespace mlx::core {
namespace { namespace {
void set_unary_output_data(const array& in, array& out) { void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) { if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
auto size = in.data_size(); auto size = in.data_size();
@@ -24,36 +24,22 @@ void set_unary_output_data(const array& in, array& out) {
} }
} }
template <typename T, typename U = T, typename Op> template <typename T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) { void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>(); const T* a_ptr = a.data<T>();
if (a.flags().contiguous) { if (a.flags().contiguous) {
set_unary_output_data(a, out); set_unary_output_data(a, out);
U* dst = out.data<U>(); T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) { for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]); dst[i] = op(a_ptr[i]);
} }
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
U* dst = out.data<U>(); T* dst = out.data<T>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1; for (size_t i = 0; i < out.size(); ++i) {
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1; // TODO this is super inefficient, need to fix.
if (a.ndim() <= 1) { int a_idx = elem_to_loc(i, a.shape(), a.strides());
unary_op(a_ptr, dst, op, shape, stride); dst[i] = op(a_ptr[a_idx]);
return;
}
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
it.step();
} }
} }
} }
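
For unary ops the new code keeps the contiguous fast path and replaces the per-element elem_to_loc fallback with an innermost-axis loop whose row offsets come from a ContiguousIterator. The inner loop is worth seeing on its own; in the sketch below the row offsets for a small transposed view are computed by hand (toy stand-in, not the MLX routine):

// Innermost-axis kernel: contiguous output, strided input, applied once per
// outer row. The iterator in the real code supplies each row's start offset.
#include <cstdio>
#include <vector>

template <typename T, typename Op>
void unary_row(const T* a, T* out, Op op, size_t n, size_t stride) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = op(*a);
    a += stride;
  }
}

int main() {
  // Negate a transposed 2x3 view of a row-major 3x2 buffer: the input stride
  // along the innermost output axis is 2, and each outer row starts 1 later.
  std::vector<float> buf = {0, 1, 2, 3, 4, 5}, out(6);
  for (size_t row = 0; row < 2; ++row) {
    unary_row(buf.data() + row, out.data() + row * 3,
              [](float x) { return -x; }, 3, 2);
  }
  for (float v : out) std::printf("%g ", v);  // -0 -2 -4 -1 -3 -5
  std::printf("\n");
}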

View File

@@ -1,138 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<StrideT>>& strides,
StrideT size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
}
size_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const std::vector<StrideT>& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
if (shape[i] != 1) {
to_collapse.push_back(i);
}
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
};
if (i == to_collapse.size()) {
break;
}
int current_shape = shape[to_collapse[i]];
int k = i;
while (to_collapse[++k] != -1) {
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<StrideT>& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
}
if (!shape.empty() && out_shape.empty()) {
out_shape.push_back(1);
for (auto& out_stride : out_strides) {
out_stride.push_back(0);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
collapsed_strides.push_back(strides[0]);
for (int i = 1; i < shape.size(); i++) {
if (shape[i] == 1) {
continue;
} else if (
strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
collapsed_shape.back() *= shape[i];
collapsed_strides.back() = strides[i];
}
}
}
return std::make_pair(collapsed_shape, collapsed_strides);
}
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
a.shape(), a.strides(), size_cap);
}
} // namespace mlx::core
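
The removed file above holds the out-of-line collapse_contiguous_dims implementations. The single shape/strides variant reduces to one rule: axis i folds into the axis before it when stepping shape[i] times with strides[i] lands exactly on one step of the outer axis, i.e. strides[i] * shape[i] == strides[i-1]. A standalone sketch of that rule (without the size_cap bookkeeping in the real code):

// Sketch: merge axes that are jointly contiguous so later kernels see the
// smallest possible rank. Size-1 axes are dropped outright.
#include <cstdio>
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<long>> collapse(
    const std::vector<int>& shape, const std::vector<long>& strides) {
  std::vector<int> cshape;
  std::vector<long> cstrides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 1) continue;  // singleton axes never block a merge
    if (!cshape.empty() && strides[i] * shape[i] == cstrides.back()) {
      cshape.back() *= shape[i];  // fold axis i into the previous one
      cstrides.back() = strides[i];
    } else {
      cshape.push_back(shape[i]);
      cstrides.push_back(strides[i]);
    }
  }
  return {cshape, cstrides};
}

int main() {
  // transpose(reshape(arange(8), {2,2,2}), {2,0,1}) has shape {2,2,2} and
  // strides {1,4,2}: the last two axes collapse into one axis of size 4.
  auto [s, st] = collapse({2, 2, 2}, {1, 4, 2});
  for (size_t i = 0; i < s.size(); ++i) std::printf("(%d,%ld) ", s[i], st[i]);
  std::printf("\n");  // (2,1) (4,2)
}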

View File

@@ -8,12 +8,12 @@
namespace mlx::core { namespace mlx::core {
template <typename StrideT> template <typename stride_t>
inline StrideT elem_to_loc( inline stride_t elem_to_loc(
int elem, int elem,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<StrideT>& strides) { const std::vector<stride_t>& strides) {
StrideT loc = 0; stride_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) { for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]); auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i]; loc += q_and_r.rem * strides[i];
@@ -29,41 +29,64 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides()); return elem_to_loc(elem, a.shape(), a.strides());
} }
template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<StrideT> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
return strides;
}
// Collapse dims that are contiguous to possibly route to a better kernel // Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1}) // e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}. // should return {{2, 4}, {{1, 2}}}.
// //
// When multiple arrays are passed they should all have the same shape. The // When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned. // collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>> template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims( collapse_contiguous_dims(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides, const std::vector<std::vector<stride_t>> strides) {
int64_t size_cap = std::numeric_limits<int32_t>::max()); // Make a vector that has axes separated with -1. Collapse all axes between
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>> // -1.
collapse_contiguous_dims( std::vector<int> to_collapse;
const std::vector<int>& shape, if (shape.size() > 0) {
const std::vector<std::vector<size_t>>& strides, to_collapse.push_back(0);
size_t size_cap = std::numeric_limits<int32_t>::max()); for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>> inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims( collapse_contiguous_dims(const std::vector<array>& xs) {
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<std::vector<size_t>> strides; std::vector<std::vector<size_t>> strides;
for (auto& x : xs) { for (auto& x : xs) {
strides.emplace_back(x.strides()); strides.emplace_back(x.strides());
} }
return collapse_contiguous_dims(xs[0].shape(), strides, size_cap); return collapse_contiguous_dims(xs[0].shape(), strides);
} }
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>> template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
@@ -72,110 +95,27 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
std::vector<array>{std::forward<Arrays>(xs)...}); std::vector<array>{std::forward<Arrays>(xs)...});
} }
// The single array version of the above. template <typename stride_t>
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max());
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
if (dims == 0) {
return;
}
int i = dims - 1;
while (pos_[i] == (shape_[i] - 1) && i > 0) {
pos_[i] = 0;
loc -= (shape_[i] - 1) * strides_[i];
i--;
}
pos_[i]++;
loc += strides_[i];
}
void seek(StrideT n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
loc += q_and_r.rem * strides_[i];
pos_[i] = q_and_r.rem;
n = q_and_r.quot;
}
}
void reset() {
loc = 0;
std::fill(pos_.begin(), pos_.end(), 0);
}
ContiguousIterator() {};
explicit ContiguousIterator(const array& a)
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
StrideT loc{0};
private:
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
};
template <typename StrideT>
inline auto check_contiguity( inline auto check_contiguity(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<StrideT>& strides) { const std::vector<stride_t>& strides) {
size_t no_broadcast_data_size = 1; size_t data_size = 1;
size_t f_stride = 1; size_t f_stride = 1;
size_t b_stride = 1; size_t b_stride = 1;
bool is_row_contiguous = true; bool is_row_contiguous = true;
bool is_col_contiguous = true; bool is_col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
is_col_contiguous &= strides[i] == f_stride || shape[i] == 1; is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1; is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i]; f_stride *= shape[i];
b_stride *= shape[ri]; b_stride *= shape[ri];
if (strides[i] > 0) { if (strides[i] > 0) {
no_broadcast_data_size *= shape[i]; data_size *= shape[i];
} }
} }
return std::make_tuple( return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}
inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;
} }
} // namespace mlx::core } // namespace mlx::core
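
In the check_contiguity hunk the two sides disagree on which flag gets the front-walking check; the convention is easier to see spelled out. Row-contiguous means the strides match a C-order layout built from the last axis, column-contiguous means they match a Fortran-order layout built from the first, and stride-0 (broadcast) axes are excluded from the stored-element count. A small self-contained sketch:

// Sketch: derive (stored element count, row-contiguous, col-contiguous)
// from a shape/strides pair. Size-1 axes are ignored for both flags.
#include <cstdio>
#include <tuple>
#include <vector>

std::tuple<size_t, bool, bool> check_contiguity_sketch(
    const std::vector<int>& shape, const std::vector<size_t>& strides) {
  size_t data_size = 1;  // elements actually stored (no broadcast copies)
  size_t f_stride = 1;   // expected stride walking from the front (col order)
  size_t b_stride = 1;   // expected stride walking from the back (row order)
  bool row = true, col = true;
  for (int i = 0, ri = static_cast<int>(shape.size()) - 1; ri >= 0; ++i, --ri) {
    col &= strides[i] == f_stride || shape[i] == 1;
    row &= strides[ri] == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) data_size *= shape[i];
  }
  return {data_size, row, col};
}

int main() {
  auto [n1, r1, c1] = check_contiguity_sketch({3, 4}, {4, 1});
  std::printf("n=%zu row=%d col=%d\n", n1, r1, c1);  // n=12 row=1 col=0
  auto [n2, r2, c2] = check_contiguity_sketch({3, 4}, {1, 3});
  std::printf("n=%zu row=%d col=%d\n", n2, r2, c2);  // n=12 row=0 col=1
}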

View File

@@ -1,56 +1,97 @@
function(make_jit_source SRC_FILE) function(make_jit_source SRC_FILE)
# This function takes a metal header file, runs the C preprocessor on it, # This function takes a metal header file,
# and makes the processed contents available as a string in a C++ function # runs the C preprocessor on it, and makes
# the processed contents available as a string in a C++ function
# mlx::core::metal::${SRC_NAME}() # mlx::core::metal::${SRC_NAME}()
# #
# To use the function, declare it in jit/includes.h and include # To use the function, declare it in jit/includes.h and
# jit/includes.h. # include jit/includes.h.
# #
# Additional arguments to this function are treated as dependencies in the # Additional arguments to this function are treated as dependencies
# CMake build system. # in the CMake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME) get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command( add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp OUTPUT jit/${SRC_NAME}.cpp
COMMAND COMMAND /bin/bash
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/jit
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}" ${CMAKE_C_COMPILER}
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN}) ${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME}) add_dependencies(mlx ${SRC_NAME})
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp) target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
endfunction(make_jit_source) endfunction(make_jit_source)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h) make_jit_source(
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) utils
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h) make_jit_source(
make_jit_source(scatter kernels/indexing.h) reduce_utils
make_jit_source(gather kernels/indexing.h) kernels/atomic.h
make_jit_source(hadamard) kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)
if(MLX_METAL_JIT) if (MLX_METAL_JIT)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp) target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange) make_jit_source(arange)
make_jit_source(copy) make_jit_source(copy)
make_jit_source(unary) make_jit_source(unary)
make_jit_source(binary) make_jit_source(binary)
make_jit_source(binary_two) make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h) make_jit_source(
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(ternary) make_jit_source(ternary)
make_jit_source(softmax) make_jit_source(softmax)
make_jit_source(scan) make_jit_source(scan)
make_jit_source(sort) make_jit_source(sort)
make_jit_source( make_jit_source(
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h reduce
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h) kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
)
make_jit_source( make_jit_source(
steel/gemm/gemm kernels/steel/utils.h kernels/steel/gemm/loader.h steel/gemm/gemm
kernels/steel/gemm/mma.h kernels/steel/gemm/params.h kernels/steel/utils.h
kernels/steel/gemm/transforms.h) kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused) make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h) make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk) make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source( make_jit_source(
steel/conv/conv steel/conv/conv
@@ -61,52 +102,58 @@ if(MLX_METAL_JIT)
kernels/steel/conv/params.h kernels/steel/conv/params.h
kernels/steel/conv/loader.h kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h) kernels/steel/conv/loaders/loader_channel_n.h
make_jit_source(steel/conv/kernels/steel_conv) )
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h make_jit_source(
kernels/steel/conv/loaders/loader_general.h) steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized) make_jit_source(quantized)
make_jit_source(gemv_masked)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp) target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
endif() endif()
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp )
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
if(NOT MLX_METAL_PATH) if (NOT MLX_METAL_PATH)
set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
endif() endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
target_compile_definitions(mlx target_compile_definitions(
PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
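
make_jit_source turns each Metal header into an ordinary C++ translation unit exposing the preprocessed source as a string, so the JIT path can hand it to the Metal compiler at run time. The generated layout comes from make_compiled_preamble.sh and is not part of this diff; a hypothetical generated file would amount to roughly the following (function name, return type, and contents assumed from the comment above, not taken from the real script):

// Hypothetical sketch of what a generated jit/<name>.cpp boils down to: the
// preprocessed kernel source embedded as a raw string, exposed through a
// function the JIT can call when it needs to compile the kernel.
#include <cstdio>
#include <string>

namespace mlx::core::metal {

const std::string& utils() {  // one such accessor per make_jit_source() call
  static const std::string source = R"METAL(
    // ...preprocessed contents of kernels/utils.h would be inlined here...
    template <typename T> T clamp0(T x) { return x < T(0) ? T(0) : x; }
  )METAL";
  return source;
}

}  // namespace mlx::core::metal

int main() {
  std::printf("%zu bytes of kernel source\n", mlx::core::metal::utils().size());
}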

View File

@@ -2,7 +2,6 @@
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include <mach/vm_page_size.h> #include <mach/vm_page_size.h>
#include <unistd.h> #include <unistd.h>
@@ -141,7 +140,6 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator() MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) { buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]); auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ = block_limit_ =
@@ -150,8 +148,6 @@ MetalAllocator::MetalAllocator()
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()), static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_); block_limit_);
max_pool_size_ = block_limit_; max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
} }
size_t MetalAllocator::set_cache_limit(size_t limit) { size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -168,12 +164,6 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
return limit; return limit;
}; };
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers // Metal doesn't like empty buffers
if (size == 0) { if (size == 0) {
@@ -215,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed // Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared; size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked; res_opt |= MTL::ResourceHazardTrackingModeTracked;
lk.unlock(); lk.unlock();
buf = device_->newBuffer(size, res_opt); buf = device_->newBuffer(size, res_opt);
lk.lock(); lk.lock();
@@ -230,8 +220,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)}; return Buffer{static_cast<void*>(buf)};
} }
@@ -243,7 +231,6 @@ void MetalAllocator::clear_cache() {
void MetalAllocator::free(Buffer buffer) { void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr()); auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length(); active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
@@ -254,16 +241,9 @@ void MetalAllocator::free(Buffer buffer) {
} }
} }
size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
MetalAllocator& allocator() { MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator static MetalAllocator allocator_;
// will not be called on exit and buffers in the cache will be leaked. This return allocator_;
// can save some time at program exit.
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
} }
size_t set_cache_limit(size_t limit) { size_t set_cache_limit(size_t limit) {
@@ -272,15 +252,6 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed); return allocator().set_memory_limit(limit, relaxed);
} }
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
}
size_t get_active_memory() { size_t get_active_memory() {
return allocator().get_active_memory(); return allocator().get_active_memory();
} }
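
Most of the allocator diff is residency-set plumbing being added on one side and absent on the other, but the shared core is a size-keyed cache of Metal buffers consulted before allocating and trimmed against a pool limit. The host-memory sketch below illustrates just that caching idea (no Metal, no locking, hypothetical names):

// Sketch of a caching allocator: free() parks buffers in a size-keyed pool,
// malloc() reuses a cached block of the same size, and the cache is trimmed
// when it exceeds a limit. Real allocators add locking, alignment, residency
// tracking, and smarter eviction.
#include <cstdio>
#include <cstdlib>
#include <map>

class CachingAllocator {
 public:
  explicit CachingAllocator(size_t max_cache_bytes) : limit_(max_cache_bytes) {}
  ~CachingAllocator() { trim(0); }

  void* malloc(size_t size) {
    auto it = cache_.find(size);
    if (it != cache_.end()) {      // served from the cache
      void* p = it->second;
      cached_bytes_ -= it->first;
      cache_.erase(it);
      return p;
    }
    return std::malloc(size);      // otherwise ask the system
  }

  void free(void* p, size_t size) {
    cache_.emplace(size, p);       // recycle instead of releasing
    cached_bytes_ += size;
    if (cached_bytes_ > limit_) trim(limit_);
  }

 private:
  void trim(size_t target) {
    while (cached_bytes_ > target && !cache_.empty()) {
      auto it = cache_.begin();    // simplistic eviction order
      cached_bytes_ -= it->first;
      std::free(it->second);
      cache_.erase(it);
    }
  }

  std::multimap<size_t, void*> cache_;
  size_t cached_bytes_ = 0;
  size_t limit_;
};

int main() {
  CachingAllocator alloc(1 << 20);
  void* a = alloc.malloc(4096);
  alloc.free(a, 4096);
  void* b = alloc.malloc(4096);    // reuses the cached block
  std::printf("reused=%d\n", a == b);  // reused=1
  alloc.free(b, 4096);
}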

View File

@@ -8,7 +8,6 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal { namespace mlx::core::metal {
@@ -57,7 +56,6 @@ class MetalAllocator : public allocator::Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() { size_t get_active_memory() {
return active_memory_; return active_memory_;
}; };
@@ -73,7 +71,6 @@ class MetalAllocator : public allocator::Allocator {
}; };
size_t set_cache_limit(size_t limit); size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed); size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache(); void clear_cache();
private: private:
@@ -84,15 +81,12 @@ class MetalAllocator : public allocator::Allocator {
// Caching allocator // Caching allocator
BufferCache buffer_cache_; BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats // Allocation stats
size_t block_limit_; size_t block_limit_;
size_t gc_limit_; size_t gc_limit_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
size_t max_pool_size_; size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true}; bool relaxed_{true};
std::mutex mutex_; std::mutex mutex_;

View File

@@ -1,4 +1,5 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
@@ -18,47 +19,12 @@
namespace mlx::core { namespace mlx::core {
std::string get_kernel_name( constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim,
int work_per_thread) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= 3) {
kname << ndim;
} else {
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
}
break;
}
kname << "_" << op << type_to_name(a);
return kname.str();
}
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op, const std::string op,
const Stream& s) { const Stream& s) {
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
@@ -70,68 +36,80 @@ void binary_op_gpu_inplace(
} }
// Try to collapse contiguous dims // Try to collapse contiguous dims
auto maybe_collapse = [bopt, &a, &b, &out]() { auto [shape, strides] = collapse_contiguous_dims(a, b, out);
if (bopt == BinaryOpType::General) { auto& strides_a = strides[0];
auto [shape, strides] = collapse_contiguous_dims(a, b, out); auto& strides_b = strides[1];
return std::make_tuple(shape, strides[0], strides[1], strides[2]); auto& strides_out = strides[2];
} else {
std::vector<size_t> e; std::string kernel_name;
return std::make_tuple(std::vector<int>{}, e, e, e); {
} std::ostringstream kname;
}; switch (bopt) {
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse(); case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
bool use_2d = out.data_size() > UINT32_MAX;
auto ndim = shape.size();
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2 auto kernel =
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op) get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output // - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated // - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output. // otherwise it goes to the second output
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr; bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr; bool donate_b = b.data_shared_ptr() == nullptr;
int arg_idx = 0; compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array( compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++); donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], arg_idx++); compute_encoder.set_output_array(outputs[0], 2);
if (outputs.size() == 2) { compute_encoder.set_output_array(outputs[1], 3);
compute_encoder.set_output_array(outputs[1], arg_idx++);
}
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads // Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1); size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
} }
@@ -139,14 +117,14 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} else { } else {
// Launch a 1D or 2D grid of threads // Launch a 1D grid of threads
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
} }
@@ -154,7 +132,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op, const std::string op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@@ -168,7 +146,7 @@ void binary_op_gpu(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::string& op) { const std::string op) {
auto& s = outputs[0].primitive().stream(); auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s); binary_op_gpu(inputs, outputs, op, s);
} }
@@ -176,16 +154,106 @@ void binary_op_gpu(
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const std::string op,
const Stream& s) { const Stream& s) {
std::vector<array> outputs = {out}; auto& a = inputs[0];
binary_op_gpu_inplace(inputs, outputs, op, s); auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} }
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op, const std::string op,
const Stream& s) { const Stream& s) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
@@ -198,7 +266,7 @@ void binary_op_gpu(
void binary_op_gpu( void binary_op_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out, array& out,
const std::string& op) { const std::string op) {
auto& s = out.primitive().stream(); auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s); binary_op_gpu(inputs, out, op, s);
} }
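
The binary changes fold the ss/sv/vs/vv/g kernel-name construction into a helper, add a work-per-thread suffix for large general kernels, and switch to a 2-D grid when data_size() exceeds UINT32_MAX. The name prefix only encodes which of the classic dispatch cases applies; a standalone sketch of that classification (hypothetical flags standing in for array metadata):

// Sketch: classify a binary elementwise op by its operands. Scalars pair with
// anything cheaply; two arrays contiguous in the same order can use a flat
// "vv" loop; everything else takes the general strided kernel.
#include <cstdio>
#include <string>

struct Operand {
  bool is_scalar;        // data_size() == 1
  bool row_contiguous;
  bool col_contiguous;
};

std::string binary_op_type(const Operand& a, const Operand& b) {
  if (a.is_scalar && b.is_scalar) return "ss";
  if (a.is_scalar && (b.row_contiguous || b.col_contiguous)) return "sv";
  if (b.is_scalar && (a.row_contiguous || a.col_contiguous)) return "vs";
  if ((a.row_contiguous && b.row_contiguous) ||
      (a.col_contiguous && b.col_contiguous))
    return "vv";
  return "g";            // general: shapes and strides go to the kernel
}

int main() {
  Operand scalar{true, true, true}, rowmaj{false, true, false},
      colmaj{false, false, true};
  std::printf("%s %s %s\n",
              binary_op_type(scalar, rowmaj).c_str(),   // sv
              binary_op_type(rowmaj, rowmaj).c_str(),   // vv
              binary_op_type(rowmaj, colmaj).c_str());  // g
}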
