Compare commits

..

23 Commits

Author SHA1 Message Date
Angelos Katharopoulos
f5ece1d98d Tmp 2024-01-22 18:38:11 -08:00
Angelos Katharopoulos
1e4d3c7fb2 Add a compile option to nn.value_and_grad 2024-01-21 21:25:01 -08:00
Awni Hannun
db40990d33 clear output tree from cache when function goes out of scope 2024-01-19 20:52:30 -08:00
Awni Hannun
390e50b163 fix multi-output with compile 2024-01-19 20:52:30 -08:00
Awni Hannun
ed4d867092 simplify tests use compile now 2024-01-19 20:52:30 -08:00
Awni Hannun
1c3f82ca17 remove simplify 2024-01-19 20:52:30 -08:00
Awni Hannun
5c78c16f1c enable and disable compiler 2024-01-19 20:52:30 -08:00
Awni Hannun
ba8ce23597 simplify inputs also 2024-01-19 20:52:30 -08:00
Awni Hannun
df1f6c221b bug fix with move function and compile at exit 2024-01-19 20:52:30 -08:00
Awni Hannun
ecfb72157e more cpp tests 2024-01-19 20:52:30 -08:00
Awni Hannun
d34d10ebac simplify 2024-01-19 20:52:30 -08:00
Awni Hannun
b75ff47098 fix python globals bug, and erase 2024-01-19 20:52:30 -08:00
Awni Hannun
6189111494 fix test 2024-01-19 20:52:30 -08:00
Awni Hannun
4f50935c2c compile works with python closures 2024-01-19 20:52:30 -08:00
Awni Hannun
966b7faef4 fix segfault on python exit 2024-01-19 20:52:30 -08:00
Awni Hannun
0005cfe053 basic python tests 2024-01-19 20:52:30 -08:00
Awni Hannun
9739c72781 fix grad, more tests 2024-01-19 20:52:30 -08:00
Awni Hannun
9d70412a91 fix 2024-01-19 20:52:30 -08:00
Awni Hannun
7a753335d7 clean 2024-01-19 20:52:30 -08:00
Awni Hannun
cd1e5b25cc compile binding 2024-01-19 20:52:30 -08:00
Awni Hannun
21062680d5 basic compile scaffold works 2024-01-19 20:52:30 -08:00
Awni Hannun
264e9ad57e make a move on compile 2024-01-19 20:52:30 -08:00
Awni Hannun
c38d43153b fix tests for linux 2024-01-19 20:52:30 -08:00
401 changed files with 15154 additions and 51582 deletions

View File

@@ -1,8 +1,5 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
@@ -10,9 +7,6 @@ parameters:
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
linux_build_and_test:
@@ -31,24 +25,24 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libblas-dev
- run:
name: Install Python package
name: Build python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
echo "stubs"
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
python3 -m unittest discover python/tests
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
@@ -58,202 +52,169 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
machine: true
resource_class: ml-explore/m-builder
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=3.9
conda activate runner-env
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
name: Build python package
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# eval "$(conda shell.bash hook)"
# conda activate runner-env
# cd examples/extensions && python -m pip install .
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
build_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
xcode_version:
macos_version:
type: string
default: "15.2.0"
build_env:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
name: Build package
command: |
source env/bin/activate
DEV_RELEASE=1 \
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: dist/
build_linux_test_release:
build_dev_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
extra_env:
macos_version:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
default: "14"
steps:
- checkout
- run:
name: Build wheel
name: Install dependencies
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools
pip install --upgrade pybind11[global]
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
<< parameters.extra_env >> \
pip install twine
- run:
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: wheelhouse/
path: dist/
build_package:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
type: string
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
- run:
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
- store_artifacts:
path: dist/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test
- build_release:
filters:
tags:
@@ -263,56 +224,20 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test:
requires: [ hold ]
macos_version: ["13", "14"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
when: << pipeline.parameters.nightly_build >>
jobs:
- build_release:
- build_package:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
macos_version: ["13", "14"]
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
when: << pipeline.parameters.weekly_build >>
jobs:
- build_release:
- build_dev_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]
macos_version: ["13", "14"]

View File

@@ -1,15 +1,15 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
rev: v17.0.6
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 5.12.0
hooks:
- id: isort
args:

View File

@@ -7,16 +7,11 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -257,4 +252,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@@ -15,37 +15,32 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.14.0)
set(MLX_VERSION 0.0.10)
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
else()
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
@@ -70,13 +65,9 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
add_compile_definitions(_METAL_)
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
@@ -84,23 +75,20 @@ elseif (MLX_BUILD_METAL)
COMMAND_ERROR_IS_FATAL ANY)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
set(MLX_METAL_VERSION METAL_3_1)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
set(MLX_METAL_VERSION METAL_3_0)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)
@@ -110,93 +98,50 @@ elseif (MLX_BUILD_METAL)
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx PUBLIC
mlx
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION})
endif()
if (MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if (MPI_FOUND)
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare(fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

View File

@@ -1,4 +1,3 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -6,17 +6,15 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
MLX also has a fully featured C++ API, which closely mirrors the Python API.
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
@@ -70,31 +68,23 @@ in the documentation.
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
**With `pip`**:
```
pip install mlx
```
**With `conda`**:
```
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

View File

@@ -73,7 +73,6 @@ void time_unary_ops() {
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@@ -85,9 +84,7 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@@ -96,9 +93,7 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
@@ -106,7 +101,6 @@ void time_binary_ops() {
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {

View File

@@ -17,13 +17,14 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
double time_fn(F fn, Args... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@@ -72,9 +72,6 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -87,15 +84,6 @@ quant_matmul = {
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
@@ -380,6 +368,10 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
@@ -402,10 +394,6 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))

View File

@@ -331,6 +331,10 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
@@ -350,10 +354,6 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))

View File

@@ -80,8 +80,10 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
if _filter(x)
else None
)

View File

@@ -1,109 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import math
import random
import mlx.core as mx
from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
x = mx.random.uniform(shape=(1000, 1024))
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(gelu), x, msg="fixed gelu")
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x, y):
x = x[: randint()]
for _ in range(10):
x = fun(x)
y = fun(y)
return x, y
return bench_fun
y = mx.random.uniform(shape=(1000, 1024))
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
time_fn(
gen_fun(mx.compile(gelu, shapeless=True)),
x,
y,
msg="shapeless variable gelu",
)
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)
def layernorm(x):
x = x.astype(mx.float32)
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + 1e-4)
x = x.astype(mx.float16)
return weight * x + bias
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x):
x = x[: randint()]
for _ in range(10):
x = fun(x)
return x
return bench_fun
random.seed(0)
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
random.seed(0)
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
random.seed(0)
time_fn(
gen_fun(mx.compile(layernorm, shapeless=True)),
x,
msg="shapeless variable layernorm",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Compile benchmarks.")
args = parser.parse_args()
bench_gelu()
bench_layernorm()

View File

@@ -1,123 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,136 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,57 +0,0 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
return bandwidths
def time_fft():
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
if __name__ == "__main__":
time_fft()

View File

@@ -1,53 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -1,41 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -0,0 +1,198 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,118 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,199 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.mps
import torch.nn as nn
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,39 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -1,35 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(64)
# vec
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x, offset=100)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -1,96 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
mx.eval(dst)
idx = []
for idx_shape in idx_shapes:
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = []
for idx_shape in idx_shapes:
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [
(10, 64),
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(20_000_000,),
(10000, 64),
(100, 64),
(100, 10_000, 64),
(10, 100, 100, 21),
(1_000, 1_000, 10),
]
idx_shapes = [
[(1_000_000,)],
[(1_000_000,)],
[(100_000,)],
[(1_000_000,)],
[(20_000_000,)],
[(20_000_000,)],
[(1000000,)],
[(10000000,)],
[(1_000,)],
[(10_000,)],
[(1_000,), (1_000,)],
]
x_shapes = [
(1_000_000, 64),
(1_000_000, 64),
(100_000, 64),
(1_000_000,),
(20_000_000,),
(20_000_000,),
(1000000, 64),
(10000000, 64),
(1_000, 10_000, 64),
(10_000, 100, 100, 21),
(1_000, 10),
]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -44,13 +44,6 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,7 +101,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
import time
@@ -6,11 +6,7 @@ import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
@@ -24,15 +20,3 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 100
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

1
docs/.gitignore vendored
View File

@@ -1,3 +1,2 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

View File

@@ -1,50 +0,0 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,16 +2,12 @@
### Setup (do once)
Install Doxygen:
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
```
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
conda install sphinx
pip install sphinx-book-theme
```
### Build
@@ -19,7 +15,7 @@ pip install -r requirements.txt
Build the docs from `mlx/docs/`
```
doxygen && make html
make html
```
View the docs by running a server in `mlx/docs/build/html/`:

View File

@@ -1,3 +0,0 @@
sphinx
breathe
sphinx-book-theme

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 746 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 48 KiB

View File

@@ -4,17 +4,16 @@
.. autoclass:: {{ objname }}
{% block methods %}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}
{% endblock %}#}

View File

@@ -12,7 +12,7 @@ import mlx.core as mx
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
version = ".".join(mx.__version__.split()[:-1])
release = version
# -- General configuration ---------------------------------------------------
@@ -22,28 +22,22 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
main_doc = "index"
master_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output -------------------------------------------------
@@ -54,32 +48,11 @@ html_theme_options = {
"repository_url": "https://github.com/ml-explore/mlx",
"use_repository_button": True,
"navigation_with_keys": False,
"logo": {
"image_light": "_static/mlx_logo.png",
"image_dark": "_static/mlx_logo_dark.png",
},
}
html_logo = "_static/mlx_logo.png"
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]

View File

@@ -3,5 +3,4 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -1,16 +1,24 @@
Custom Extensions in MLX
========================
Developer Documentation
=======================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
MLX provides a open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
Introducing the Example
-----------------------
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
.. code-block:: python
@@ -19,35 +27,44 @@ You can do that in MLX directly:
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementation and
function transformations to MLX.
This function performs that operation while leaving the implementations and
differentiation to MLX.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
Operations and Primitives
-------------------------
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations
^^^^^^^^^^^
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
.. code-block:: C++
@@ -66,7 +83,10 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to this operation is in terms of existing operations:
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be do so in terms
of existing operations.
.. code-block:: C++
@@ -80,23 +100,25 @@ The simplest way to this operation is in terms of existing operations:
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy to use interface that use :class:`Primitive` building blocks.
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy to use interface that build
on top of the building blocks we call :class:`Primitive`.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete:
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` . Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so lets take a step
back and go to our example to give ourselves a more concrete image.
.. code-block:: C++
@@ -112,15 +134,11 @@ more concrete:
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -129,8 +147,7 @@ more concrete:
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
@@ -138,7 +155,7 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -158,22 +175,22 @@ more concrete:
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
@@ -206,7 +223,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -221,26 +238,27 @@ This operation now handles the following:
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
of these functions to allocate memory as needed
Implementing the CPU Back-end
Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
@@ -278,19 +296,19 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
}
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -303,26 +321,28 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
"Axpby is only supported for floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
floats. We can only direct to it for ``float32`` types
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call the :meth:`catlas_saxpby` from accelerate.
.. code-block:: C++
@@ -336,7 +356,17 @@ It allocates data for the output, copies ``y`` into it, and then calls the
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -359,20 +389,18 @@ It allocates data for the output, copies ``y`` into it, and then calls the
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
Great! But what about the inputs that do not fit the criteria for accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -382,33 +410,35 @@ For inputs that do not fit the criteria for accelerate, we fall back to
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
Implementing the GPU Back-end
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.
.. note::
Here are some helpful resources if you are new to Metal:
Here are some helpful resources if you are new to metal!
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
.. code-block:: C++
@@ -427,14 +457,15 @@ element in the output.
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
each instantiation a unique host name so we can identify the right kernel for
each data type.
.. code-block:: C++
@@ -457,21 +488,29 @@ each instantiation a unique host name so we can identify it.
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/ shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -479,10 +518,10 @@ below.
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
@@ -494,7 +533,7 @@ below.
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -503,17 +542,17 @@ below.
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
@@ -531,30 +570,33 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
.. code-block:: C++
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -569,12 +611,12 @@ one we just defined:
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
@@ -583,35 +625,34 @@ one we just defined:
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
throw std::runtime_error("Axpby has no vmap implementation.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
Let's look at the overall directory structure first.
| extensions
| ├── axpby
@@ -625,39 +666,40 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the Python package
the python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
.. code-block:: C++
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -669,17 +711,17 @@ already provided, adding our :meth:`axpby` is simple.
Returns:
array: ``alpha * x + beta * y``
)");
)pbdoc");
}
Most of the complexity in the above example comes from additional bells and
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
@@ -687,8 +729,8 @@ whistles such as the literal names and doc-strings.
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
Building the C++ extension library itself is simple, it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
.. code-block:: cmake
@@ -710,12 +752,12 @@ CONFIG)`` and then link it to your library.
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice:
Here is what that looks like in practice!
.. code-block:: cmake
@@ -737,29 +779,27 @@ Here is what that looks like in practice:
endif()
Finally, we build the nanobind_ bindings
Finally, we build the Pybind11_ bindings
.. code-block:: cmake
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:
build utilities defined in :mod:`mlx.extension` for a simple build process.
.. code-block:: python
.. code-block:: python
from mlx import extension
from setuptools import setup
@@ -769,50 +809,48 @@ build utilities defined in :mod:`mlx.extension`:
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.8",
python_requires=">=3.7",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This results in the directory structure:
This will result in a directory structure as follows:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
Let's look at a simple script and its results:
Let's looks at a simple script and it's results!
.. code-block:: python
@@ -825,7 +863,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
Output:
@@ -836,12 +874,12 @@ Output:
c correctness: True
Results
^^^^^^^
^^^^^^^^^^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.
.. code-block:: python
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
@@ -860,7 +898,7 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
alpha = 4.0
beta = 2.0
mx.eval(x, y)
mx.eval((x, y))
def bench(f):
# Warm up
@@ -881,23 +919,30 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!
Results:
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
The full example code is available in `mlx-examples <code>`_.
.. code: `TODO_LINK/extensions`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@@ -1,68 +0,0 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -41,7 +41,6 @@ are the CPU and GPU.
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/using_streams
@@ -58,15 +57,12 @@ are the CPU and GPU.
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/nn
python/optimizers
python/tree_utils
@@ -82,4 +78,3 @@ are the CPU and GPU.
:maxdepth: 1
dev/extensions
dev/metal_debugger

View File

@@ -15,10 +15,10 @@ To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- macOS >= 13.5
- macOS >= 13.3
.. note::
MLX is only available on devices running macOS >= 13.5
MLX is only available on devices running macOS >= 13.3
It is highly recommended to use macOS 14 (Sonoma)
@@ -54,7 +54,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@@ -70,13 +70,16 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
Then simply build and install MLX using pip:
Then simply build and install it using pip:
.. code-block:: shell
@@ -120,7 +123,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
cmake .. && make -j
Run tests with:
@@ -139,7 +142,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
@@ -153,67 +156,31 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
.. code-block:: shell
cmake ..
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots.
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
@@ -235,7 +202,7 @@ Then set the active developer directory:
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
x86 Shell
~~~~~~~~~
.. _build shell:

View File

@@ -10,38 +10,27 @@ Array
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.cos
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.dtype
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
@@ -51,8 +40,6 @@ Array
array.split
array.sqrt
array.square
array.squeeze
array.swapaxes
array.sum
array.transpose
array.T

View File

@@ -1,5 +1,7 @@
.. _data_types:
:orphan:
Data Types
==========
@@ -42,27 +44,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``int64``
- 8
- 64-bit signed integer
* - ``bfloat16``
- 2
- 16-bit brain float (e8, m7)
* - ``float16``
- 2
- 16-bit IEEE float (e5, m10)
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
* - ``complex64``
- 8
- 64-bit complex float
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype

View File

@@ -9,11 +9,9 @@ Devices and Streams
:toctree: _autosummary
Device
Stream
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
stream
synchronize

View File

@@ -1,14 +0,0 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention

View File

@@ -8,8 +8,4 @@ Linear Algebra
.. autosummary::
:toctree: _autosummary
inv
norm
cholesky
qr
svd

View File

@@ -1,19 +0,0 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
clear_cache
start_capture
stop_capture

View File

@@ -173,7 +173,6 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
.. toctree::
@@ -181,4 +180,3 @@ In detail:
nn/layers
nn/functions
nn/losses
nn/init

View File

@@ -12,24 +12,12 @@ simple functions.
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
gelu
gelu_approx
gelu_fast_approx
glu
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu6
selu
sigmoid
silu
softmax
softplus
softshrink
step
tanh

View File

@@ -1,45 +0,0 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@@ -10,39 +10,28 @@ Layers
:template: nn-module-template.rst
ALiBi
AvgPool1d
AvgPool2d
BatchNorm
Conv1d
Conv2d
Conv3d
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GroupNorm
GRU
InstanceNorm
LayerNorm
Linear
LSTM
MaxPool1d
MaxPool2d
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softshrink
Step
Transformer
Upsample

View File

@@ -18,7 +18,6 @@ Loss Functions
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss

View File

@@ -11,7 +11,6 @@ Module
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods
@@ -30,7 +29,6 @@ Module
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze

View File

@@ -5,14 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
addmm
all
allclose
allclose
any
arange
arccos
@@ -20,67 +19,42 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
broadcast_to
ceil
clip
concatenate
conj
conjugate
convolve
conv1d
conv2d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
divide
divmod
equal
erf
erfinv
exp
expm1
expand_dims
eye
flatten
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
identity
inner
isclose
isinf
isnan
isneginf
isposinf
issubdtype
left_shift
isneginf
isinf
less
less_equal
linspace
@@ -98,28 +72,22 @@ Operations
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
negative
not_equal
ones
ones_like
outer
partition
pad
power
prod
quantize
quantized_matmul
radians
reciprocal
remainder
repeat
reshape
right_shift
round
rsqrt
save
@@ -138,7 +106,6 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum
@@ -148,9 +115,6 @@ Operations
tan
tanh
tensordot
tile
topk
trace
transpose
tri
tril

View File

@@ -1,7 +1,5 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -31,13 +29,19 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. toctree::
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
clip_grad_norm
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -1,20 +0,0 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -1,23 +0,0 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@@ -1,15 +0,0 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
cosine_decay
exponential_decay
join_schedules
linear_schedule
step_decay

View File

@@ -38,7 +38,6 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
gumbel
key
normal
multivariate_normal
randint
seed
split

View File

@@ -9,11 +9,9 @@ Transforms
:toctree: _autosummary
eval
compile
disable_compile
enable_compile
grad
value_and_grad
jvp
vjp
vmap
simplify

View File

@@ -19,5 +19,3 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

View File

@@ -1,430 +0,0 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@@ -5,12 +5,9 @@ Function Transforms
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations check-out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.
Here is a simple example:
@@ -39,10 +36,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
depth. To see the complete list of function transformations check-out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiaion <auto diff>` and
:ref:`automatic vectorization <vmap>`.
Automatic Differentiation
-------------------------

View File

@@ -18,9 +18,9 @@ describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation lets us record a compute graph without actually doing any
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
:func:`vmap` and graph optimizations like :func:`simplify`.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to

View File

@@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy")
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:

View File

@@ -8,5 +8,3 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)

View File

@@ -1,22 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_reduce_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
metal::stop_capture();
}

View File

@@ -89,8 +89,8 @@ void automatic_differentiation() {
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = grad(grad(fn))(x);
// d2fdx2 is 2
auto df2dx2 = grad(grad(fn))(x);
// df2dx2 is 2
}
int main() {

View File

@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.27)
cmake_minimum_required(VERSION 3.24)
project(_ext LANGUAGES CXX)
project(mlx_sample_extensions LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
@@ -11,12 +11,8 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
@@ -42,6 +38,7 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
@@ -57,15 +54,13 @@ if(MLX_BUILD_METAL)
endif()
# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
# ----------------------------- Pybind -----------------------------
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()

View File

@@ -1,24 +0,0 @@
## Build
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```
## Test
```
python test.py
```

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
@@ -43,7 +43,7 @@ array axpby(
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -61,7 +61,7 @@ array axpby(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -106,12 +106,12 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& out_arr) {
auto out = out_arr[0];
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -150,7 +150,11 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -176,11 +180,11 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& outarr) {
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -191,7 +195,7 @@ void Axpby::eval_cpu(
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
eval(inputs, outarr);
}
#else // Accelerate not available
@@ -199,8 +203,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
std::vector<array>& out) {
eval(inputs, out);
}
#endif
@@ -214,12 +218,12 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& outarr) {
// Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -257,7 +261,7 @@ void Axpby::eval_gpu(
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -266,11 +270,11 @@ void Axpby::eval_gpu(
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -296,7 +300,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available
@@ -368,4 +372,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
/** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(const std::vector<array>& inputs, std::vector<array>& out);
};
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -19,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
@@ -31,30 +31,30 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);

View File

@@ -1,31 +1,31 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "axpby/axpby.h"
namespace nb = nanobind;
using namespace nb::literals;
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ NB_MODULE(_ext, m) {
Returns:
array: ``alpha * x + beta * y``
)");
}
)pbdoc");
}

View File

@@ -2,4 +2,4 @@
import mlx.core as mx
from ._ext import axpby
from .mlx_sample_extensions import *

View File

@@ -1,8 +1,3 @@
[build-system]
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.9.0",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +0,0 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
from setuptools import setup
@@ -9,11 +9,11 @@ if __name__ == "__main__":
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False,
python_requires=">=3.8",
)

View File

@@ -1,10 +0,0 @@
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -3,10 +3,9 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
@@ -19,17 +18,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
if (MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
else()
target_sources(
mlx
PRIVATE

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr) {};
Buffer(void* ptr) : ptr_(ptr){};
// Get the raw data pointer from the buffer
void* raw_ptr();

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <functional>
#include "mlx/array.h"
@@ -11,6 +12,16 @@ namespace mlx::core {
namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
@@ -25,11 +36,22 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
init(&cval);
}
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
dtype,
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
@@ -37,16 +59,15 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -61,20 +82,13 @@ array::array(std::initializer_list<float> data)
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
set_data(data, deleter);
}
@@ -92,13 +106,7 @@ void array::detach() {
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::scheduled) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
mlx::core::eval({*this});
}
mlx::core::eval({*this});
}
bool array::is_tracer() const {
@@ -147,115 +155,38 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::~array() {
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
if (!do_detach) {
break;
}
}
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
array::ArrayDesc::ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
}
array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
}
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destory their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
@@ -9,7 +8,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
namespace mlx::core {
@@ -33,7 +31,7 @@ class array {
template <typename It>
array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -43,19 +41,16 @@ class array {
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter = allocator::free);
@@ -114,15 +109,6 @@ class array {
return array_desc_->strides;
};
/**
* Get the stride of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
};
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
@@ -135,9 +121,6 @@ class array {
template <typename T>
T item();
template <typename T>
T item() const;
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t;
@@ -183,16 +166,22 @@ class array {
* API may change.
*/
array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
std::vector<array>&& inputs);
static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@@ -209,7 +198,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {};
: buffer(buffer), d(d){};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -251,21 +240,11 @@ class array {
return array_desc_->inputs;
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
};
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
@@ -303,12 +282,6 @@ class array {
return array_desc_->data->buffer;
};
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
@@ -319,27 +292,9 @@ class array {
return static_cast<T*>(array_desc_->data_ptr);
};
enum Status { unscheduled, scheduled, available };
bool is_available() const {
return status() == Status::available;
}
const Status status() const {
return array_desc_->status;
}
void set_status(Status s) const {
array_desc_->status = s;
}
// Get the array's shared event
Event& event() const {
return array_desc_->event;
}
// Attach an event to a not yet evaluated array
void attach_event(Event e) const {
array_desc_->event = std::move(e);
// Check if the array has been evaluated
bool is_evaled() const {
return array_desc_->data != nullptr;
}
// Mark the array as a tracer array (true) or not.
@@ -367,21 +322,10 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
~array();
private:
// Initialize the arrays data
template <typename It>
@@ -392,12 +336,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
Status status;
// An event on the array used for synchronization
Event event;
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -405,7 +344,7 @@ class array {
// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data;
std::shared_ptr<Data> data{nullptr};
// Properly offset data pointer
void* data_ptr{nullptr};
@@ -425,26 +364,26 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
const std::vector<array>& inputs);
~ArrayDesc();
private:
// Initialize size, strides, and other metadata
void init();
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_;
// and a the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
template <typename T>
@@ -456,9 +395,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
init(data);
}
@@ -475,9 +414,9 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
@@ -494,19 +433,6 @@ T array::item() {
return *data<T>();
}
template <typename T>
T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (status() == Status::unscheduled) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
const_cast<array*>(this)->eval();
return *data<T>();
}
template <typename It>
void array::init(It src) {
set_data(allocator::malloc(size() * size_of(dtype())));
@@ -553,15 +479,4 @@ void array::init(It src) {
}
}
/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
@@ -46,14 +46,6 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -97,14 +89,6 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
@@ -196,40 +180,6 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
return matmul_bnns_general(a_pre, b_pre, out);
}
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,4 +201,4 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
@@ -31,24 +31,16 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
@@ -58,41 +50,46 @@ DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(DivMod)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
eval(inputs, out);
unary(in, out, AbsOp());
}
}
@@ -140,8 +137,12 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -152,8 +153,12 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -164,8 +169,12 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -176,8 +185,12 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -188,40 +201,28 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -233,23 +234,30 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
@@ -261,8 +269,12 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -273,8 +285,12 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -319,14 +335,57 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
@@ -335,19 +394,6 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -364,8 +410,12 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) {
case Base::e:
vvlogf(
@@ -389,11 +439,15 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
@@ -402,6 +456,47 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -431,8 +526,13 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return -x; });
}
@@ -445,13 +545,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
@@ -493,8 +587,12 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -505,8 +603,12 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -517,8 +619,12 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
@@ -529,8 +635,12 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
@@ -585,8 +695,12 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -597,8 +711,12 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);

View File

@@ -24,6 +24,8 @@ void _qmm_t_4_64(
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;

View File

@@ -10,65 +10,78 @@
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
};
}
} // namespace
// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
VT _sum = {0};
while (size >= N) {
_sum += (*(VT*)x);
x += N;
size -= N;
}
T sum = _sum[0];
for (int i = 1; i < N; i++) {
sum += _sum[i];
}
*accum += sum;
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -81,11 +94,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_sum<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
@@ -99,11 +111,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_max<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
@@ -117,11 +128,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_min<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -218,21 +218,13 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
T maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
maximum = std::max(maximum, *current_in_ptr);
current_in_ptr++;
}
@@ -242,29 +234,18 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
T normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
@@ -273,33 +254,14 @@ void softmax(const array& in, array& out) {
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
*current_out_ptr *= normalizer;
current_out_ptr++;
}
}
@@ -312,12 +274,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
@@ -346,29 +303,15 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
}
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
break;
case bfloat16:
eval(inputs, out);

View File

@@ -1,75 +1,19 @@
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${COMPILER}
${PROJECT_SOURCE_DIR}
${CLANG}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
if (IOS)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
endif()

View File

@@ -7,7 +7,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -74,7 +73,7 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void DivMod::eval(
@@ -136,59 +135,93 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
binary(a, b, out, RemainderFn{});
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
return x == y || (std::isnan(x) && std::isnan(y));
});
} else {
comparison_op(inputs[0], inputs[1], out, detail::Equal());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
}
}
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
}
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
}
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
}
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
}
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
auto op = [](auto x, auto y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = (x > y) ? x : y;
auto minval = (x > y) ? y : x;
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(std::exp(minval - maxval)));
};
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, op);
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, op);
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, op);
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
@@ -200,118 +233,65 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
}
struct PowerFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
if (exp < 0) {
throw std::invalid_argument(
"Integers cannot be raise to negative powers");
}
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
binary(a, b, out, PowerFn{});
}
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
} // namespace mlx::core

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
@@ -11,7 +10,7 @@ namespace mlx::core {
namespace {
enum class BinaryOpType {
enum BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
@@ -22,17 +21,17 @@ enum class BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = BinaryOpType::ScalarScalar;
bopt = ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
} else {
bopt = BinaryOpType::General;
bopt = General;
}
return bopt;
}
@@ -41,83 +40,29 @@ void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
BinaryOpType bopt) {
switch (bopt) {
case BinaryOpType::ScalarScalar:
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
case ScalarVector:
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
break;
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
case VectorScalar:
case VectorVector:
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
break;
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
case General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
@@ -426,25 +371,25 @@ void binary_op(
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
if (bopt == ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
if (bopt == ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
if (bopt == VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
if (bopt == VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
@@ -477,17 +422,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
dim = d;
}
@@ -497,20 +442,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:

View File

@@ -260,14 +260,14 @@ void binary_op(
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
if (bopt == ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
if (bopt == ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
if (bopt == VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
if (bopt == VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
dim = d;
}
@@ -347,20 +347,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:

View File

@@ -1,101 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
}
}
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core

View File

@@ -1,347 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -1,227 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void print_constant(std::ostream& os, const array& x) {
switch (x.dtype()) {
case float32:
return print_float_constant<float>(os, x);
case float16:
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
return print_int_constant<int8_t>(os, x);
case int16:
return print_int_constant<int16_t>(os, x);
case int32:
return print_int_constant<int32_t>(os, x);
case int64:
return print_int_constant<int64_t>(os, x);
case uint8:
return print_int_constant<uint8_t>(os, x);
case uint16:
return print_int_constant<uint16_t>(os, x);
case uint32:
return print_int_constant<uint32_t>(os, x);
case uint64:
return print_int_constant<uint64_t>(os, x);
case bool_:
os << std::boolalpha << x.item<bool>();
return;
default:
throw std::runtime_error("Unsupported constant type");
}
}
std::string get_type_string(Dtype d) {
switch (d) {
case float32:
return "float";
case float16:
return "float16_t";
case bfloat16:
return "bfloat16_t";
case complex64:
return "complex64_t";
case bool_:
return "bool";
case int8:
return "int8_t";
case int16:
return "int16_t";
case int32:
return "int32_t";
case int64:
return "int64_t";
case uint8:
return "uint8_t";
case uint16:
return "uint16_t";
case uint32:
return "uint32_t";
case uint64:
return "uint64_t";
default: {
std::ostringstream msg;
msg << "Unsupported compilation type " << d;
throw std::runtime_error(msg.str());
}
}
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (const auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
} else if (non_scalar_inputs == 0 && !shape.empty()) {
contiguous = false;
}
return contiguous;
}
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
// - Not a scalar
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
}
} // namespace mlx::core

View File

@@ -1,70 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/primitives.h"
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d);
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
os << x.item<T>();
}
template <typename T>
void print_complex_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
T constant = x.item<T>();
os << get_type_string(x.dtype()) << "("
<< std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< constant.real() << ", " << constant.imag() << ")"
<< std::setprecision(old_precision);
}
void print_constant(std::ostream& os, const array& x);
inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
} // namespace mlx::core

View File

@@ -1,356 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <filesystem>
#include <list>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
constexpr int max_file_name_length = 245;
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
lib_exists = f.good();
}
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
<< " with error code " << return_code << "." << std::endl;
throw std::runtime_error(msg.str());
}
}
// load library
libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
<< kernel_name << std::endl
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
return fun;
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
if (contiguous) {
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
} else {
for (int d = 0; d < ndim; ++d) {
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
<< "]; ++i" << d << ") {" << std::endl;
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[i];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
<< xname << ";" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
if (contiguous) {
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
} else {
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
}
// Close loops
if (contiguous) {
os << " }" << std::endl;
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
os << " " << xname << " += " << xname << "_strides[" << d << "];"
<< std::endl;
if (d < ndim - 1) {
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
<< " * shape[" << d + 1 << "];" << std::endl;
}
}
os << " }" << std::endl;
}
}
// Finish the kernel
os << "}" << std::endl;
}
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
kernel_name,
inputs_,
outputs_,
tape_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
for (auto& x : outputs) {
args.push_back(x.data<void>());
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
}
} // namespace mlx::core

View File

@@ -1,23 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is not available so check if the device is a GPU
// device.
namespace detail {
bool compile_available_for_device(const Device& device) {
return device == Device::gpu;
}
} // namespace detail
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform.");
}
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More