Compare commits

..

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
5b46b9bc52 Add priority-queue eval 2024-01-13 19:58:20 -08:00
Awni Hannun
fd94be28ea fix test + choose stream with slight care 2024-01-13 13:34:27 -08:00
Awni Hannun
9051fa1eaa Use a dummy primitive to only sync with one output 2024-01-13 13:08:19 -08:00
339 changed files with 10952 additions and 42421 deletions

View File

@@ -1,8 +1,5 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
@@ -10,9 +7,6 @@ parameters:
weekly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
linux_build_and_test:
@@ -31,29 +25,19 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libblas-dev
- run:
name: Install Python package
name: Build python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
echo "stubs"
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
python3 -m unittest discover python/tests
- run:
name: Build CPP only
command: |
@@ -63,192 +47,162 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
machine: true
resource_class: ml-explore/m-builder
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.8
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=3.9
conda activate runner-env
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
name: Build python package
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
build_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
xcode_version:
macos_version:
type: string
default: "15.2.0"
build_env:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
name: Build package
command: |
source env/bin/activate
DEV_RELEASE=1 \
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
- run:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
- run:
name: Build Python package
command: |
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
source env/bin/activate
twine upload dist/*
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: dist/
build_linux_test_release:
build_dev_release:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
extra_env:
macos_version:
type: string
default: "DEV_RELEASE=1"
docker:
- image: ubuntu:20.04
default: "14"
steps:
- checkout
- run:
name: Build wheel
name: Install dependencies
command: |
PYTHON=python<< parameters.python_version >>
apt-get update
apt-get upgrade -y
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
apt-get install -y apt-utils
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
apt-get install -y build-essential git
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade setuptools
pip install --upgrade pybind11[global]
pip install numpy
pip install auditwheel
pip install patchelf
pip install build
<< parameters.extra_env >> \
pip install twine
- run:
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
python setup.py bdist_wheel
twine upload dist/* --repository mlx
- store_artifacts:
path: wheelhouse/
path: dist/
build_package:
machine: true
resource_class: ml-explore/m-builder
parameters:
python_version:
type: string
default: "3.9"
macos_version:
type: string
default: "14"
steps:
- checkout
- run:
name: Install dependencies
command: |
eval "$(conda shell.bash hook)"
rm -r $CONDA_PREFIX/envs/runner-env
conda create -y -n runner-env python=<< parameters.python_version >>
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install numpy
pip install twine
- run:
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py bdist_wheel
- store_artifacts:
path: dist/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test
- build_release:
filters:
tags:
@@ -258,56 +212,20 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test:
requires: [ hold ]
macos_version: ["13", "14"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
when: << pipeline.parameters.nightly_build >>
jobs:
- build_release:
- build_package:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
macos_version: ["13", "14"]
weekly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.weekly_build >>
when: << pipeline.parameters.weekly_build >>
jobs:
- build_release:
- build_dev_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_linux_test_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
extra_env: ["PYPI_RELEASE=1"]
macos_version: ["13", "14"]

View File

@@ -1,15 +1,15 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.3
rev: v17.0.6
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.3.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 5.12.0
hooks:
- id: isort
args:

View File

@@ -10,12 +10,8 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -256,4 +252,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@@ -15,33 +15,32 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.12.0)
set(MLX_VERSION 0.0.9)
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
else()
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
@@ -66,13 +65,9 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
add_compile_definitions(_METAL_)
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
@@ -80,21 +75,20 @@ elseif (MLX_BUILD_METAL)
COMMAND_ERROR_IS_FATAL ANY)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)
@@ -119,27 +113,7 @@ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
@@ -149,8 +123,8 @@ else()
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
@@ -158,7 +132,7 @@ endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
@@ -166,12 +140,8 @@ target_include_directories(
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

View File

@@ -1,4 +1,3 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -6,17 +6,15 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning research on Apple silicon,
brought to you by Apple machine learning research.
MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
MLX also has a fully featured C++ API, which closely mirrors the Python API.
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
@@ -70,18 +68,10 @@ in the documentation.
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
**With `pip`**:
```
pip install mlx
```
**With `conda`**:
```
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.

View File

@@ -73,7 +73,6 @@ void time_unary_ops() {
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@@ -85,9 +84,7 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@@ -96,9 +93,7 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
@@ -106,7 +101,6 @@ void time_binary_ops() {
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {

View File

@@ -17,13 +17,14 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
double time_fn(F fn, Args... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@@ -166,13 +166,13 @@ if __name__ == "__main__":
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
(15, 1023, 1023, 1023),
(17, 1025, 1025, 1025),
)
for dtype in dtypes:

View File

@@ -60,60 +60,20 @@ def matmul(x, y):
mx.eval(ys)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
def _quant_matmul(x, w, s, b, group_size, bits):
ys = []
for i in range(10):
ys.append(
mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
)
)
ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
mx.eval(ys)
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
"quant_matmul_128_2": partial(
_quant_matmul, transpose=False, group_size=128, bits=2
),
"quant_matmul_128_4": partial(
_quant_matmul, transpose=False, group_size=128, bits=4
),
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
"quant_matmul_t_64_4": partial(
_quant_matmul, transpose=True, group_size=64, bits=4
),
"quant_matmul_t_64_8": partial(
_quant_matmul, transpose=True, group_size=64, bits=8
),
"quant_matmul_t_128_2": partial(
_quant_matmul, transpose=True, group_size=128, bits=2
),
"quant_matmul_t_128_4": partial(
_quant_matmul, transpose=True, group_size=128, bits=4
),
"quant_matmul_t_128_8": partial(
_quant_matmul, transpose=True, group_size=128, bits=8
),
"quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
"quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
"quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
"quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
}
@@ -269,13 +229,6 @@ def linear(w, b, x):
mx.eval(ys)
def linear_fused(w, b, x):
ys = []
for i in range(10):
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
mx.eval(ys)
def rope(x):
*_, N, D = x.shape
ys = []
@@ -380,6 +333,10 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
@@ -402,10 +359,6 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))
@@ -416,10 +369,7 @@ if __name__ == "__main__":
print(bench(quant_matmul[args.benchmark], *xs))
elif args.benchmark == "linear":
if args.fused:
print(bench(linear_fused, *xs))
else:
print(bench(linear, *xs))
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))

View File

@@ -331,6 +331,10 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
@@ -350,10 +354,6 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))

View File

@@ -80,8 +80,10 @@ if __name__ == "__main__":
_filter = make_predicate(args.filter, args.negative_filter)
if args.mlx_dtypes:
compare_filtered = lambda x: (
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
compare_filtered = (
lambda x: compare_mlx_dtypes(
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
)
if _filter(x)
else None
)

View File

@@ -1,109 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import math
import random
import mlx.core as mx
from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
x = mx.random.uniform(shape=(1000, 1024))
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(gelu), x, msg="fixed gelu")
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x, y):
x = x[: randint()]
for _ in range(10):
x = fun(x)
y = fun(y)
return x, y
return bench_fun
y = mx.random.uniform(shape=(1000, 1024))
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
time_fn(
gen_fun(mx.compile(gelu, shapeless=True)),
x,
y,
msg="shapeless variable gelu",
)
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)
def layernorm(x):
x = x.astype(mx.float32)
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
x = (x - means) * mx.rsqrt(var + 1e-4)
x = x.astype(mx.float16)
return weight * x + bias
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
def gen_fun(fun):
def bench_fun(x):
for _ in range(10):
x = fun(x)
return x
return bench_fun
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
def randint():
return random.randint(1, x.shape[0])
def gen_fun(fun):
def bench_fun(x):
x = x[: randint()]
for _ in range(10):
x = fun(x)
return x
return bench_fun
random.seed(0)
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
random.seed(0)
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
random.seed(0)
time_fn(
gen_fun(mx.compile(layernorm, shapeless=True)),
x,
msg="shapeless variable layernorm",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Compile benchmarks.")
args = parser.parse_args()
bench_gelu()
bench_layernorm()

View File

@@ -1,123 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,129 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,57 +0,0 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
return bandwidths
def time_fft():
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
if __name__ == "__main__":
time_fft()

View File

@@ -1,53 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_gather_mlx(x_shape, idx_shape):
def gather(x, idx):
mx.eval(x[idx])
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
x = mx.random.normal(x_shape).astype(mx.float32)
runtime = measure_runtime(gather, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_gather_torch(x_shape, idx_shape, device):
def gather(x, idx, device):
_ = x[idx]
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
x = torch.randn(x_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
idx_shapes = [(1_000_000,), (100_000,), ()]
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_gather_mlx(x_shape, idx_shape)
benchmark_gather_torch(x_shape, idx_shape, device=device)

View File

@@ -1,41 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -0,0 +1,198 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,118 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,199 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.mps
import torch.nn as nn
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -1,39 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -1,35 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def time_rope():
rope = nn.RoPE(64)
# vec
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x, offset=100)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):
for _ in range(32):
x = rope(x)
return x
time_fn(rope_mat, x)
if __name__ == "__main__":
time_rope()

View File

@@ -1,96 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import mlx.core as mx
import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
mx.eval(dst)
idx = []
for idx_shape in idx_shapes:
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = []
for idx_shape in idx_shapes:
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
if __name__ == "__main__":
parser = argparse.ArgumentParser("Gather benchmarks.")
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
device = torch.device("cpu")
else:
device = torch.device("mps")
dst_shapes = [
(10, 64),
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(20_000_000,),
(10000, 64),
(100, 64),
(100, 10_000, 64),
(10, 100, 100, 21),
(1_000, 1_000, 10),
]
idx_shapes = [
[(1_000_000,)],
[(1_000_000,)],
[(100_000,)],
[(1_000_000,)],
[(20_000_000,)],
[(20_000_000,)],
[(1000000,)],
[(10000000,)],
[(1_000,)],
[(10_000,)],
[(1_000,), (1_000,)],
]
x_shapes = [
(1_000_000, 64),
(1_000_000, 64),
(100_000, 64),
(1_000_000,),
(20_000_000,),
(20_000_000,),
(1000000, 64),
(10000000, 64),
(1_000, 10_000, 64),
(10_000, 100, 100, 21),
(1_000, 10),
]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -44,13 +44,6 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,7 +101,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
import time
@@ -6,11 +6,7 @@ import mlx.core as mx
def time_fn(fn, *args, **kwargs):
msg = kwargs.pop("msg", None)
if msg:
print(f"Timing {msg} ...", end=" ")
else:
print(f"Timing {fn.__name__} ...", end=" ")
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
@@ -24,15 +20,3 @@ def time_fn(fn, *args, **kwargs):
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")
def measure_runtime(fn, **kwargs):
# Warmup
for _ in range(5):
fn(**kwargs)
tic = time.time()
iters = 100
for _ in range(iters):
fn(**kwargs)
return (time.time() - tic) * 1000 / iters

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

1
docs/.gitignore vendored
View File

@@ -1,3 +1,2 @@
src/python/_autosummary*/
src/python/nn/_autosummary*/
src/python/optimizers/_autosummary*/

View File

@@ -1,50 +0,0 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,16 +2,12 @@
### Setup (do once)
Install Doxygen:
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
```
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
conda install sphinx
pip install sphinx-book-theme
```
### Build
@@ -19,7 +15,7 @@ pip install -r requirements.txt
Build the docs from `mlx/docs/`
```
doxygen && make html
make html
```
View the docs by running a server in `mlx/docs/build/html/`:

View File

@@ -1,3 +0,0 @@
sphinx
breathe
sphinx-book-theme

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 746 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 48 KiB

View File

@@ -4,17 +4,16 @@
.. autoclass:: {{ objname }}
{% block methods %}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}
{% endblock %}#}

View File

@@ -5,15 +5,13 @@
import os
import subprocess
import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version
version = "0.0.9"
release = "0.0.9"
# -- General configuration ---------------------------------------------------
@@ -22,28 +20,22 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
main_doc = "index"
master_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output -------------------------------------------------
@@ -54,32 +46,11 @@ html_theme_options = {
"repository_url": "https://github.com/ml-explore/mlx",
"use_repository_button": True,
"navigation_with_keys": False,
"logo": {
"image_light": "_static/mlx_logo.png",
"image_dark": "_static/mlx_logo_dark.png",
},
}
html_logo = "_static/mlx_logo.png"
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]

View File

@@ -3,5 +3,4 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -1,16 +1,24 @@
Developer Documentation
=======================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
MLX provides a open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
Introducing the Example
-----------------------
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
.. code-block:: python
@@ -19,35 +27,44 @@ You can do that in MLX directly:
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementation and
function transformations to MLX.
This function performs that operation while leaving the implementations and
differentiation to MLX.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
Operations and Primitives
-------------------------
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations
^^^^^^^^^^^
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
.. code-block:: C++
@@ -66,7 +83,10 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to this operation is in terms of existing operations:
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be do so in terms
of existing operations.
.. code-block:: C++
@@ -80,23 +100,25 @@ The simplest way to this operation is in terms of existing operations:
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy to use interface that use :class:`Primitive` building blocks.
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy to use interface that build
on top of the building blocks we call :class:`Primitive`.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete:
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` . Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so lets take a step
back and go to our example to give ourselves a more concrete image.
.. code-block:: C++
@@ -112,15 +134,11 @@ more concrete:
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -129,8 +147,7 @@ more concrete:
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
@@ -138,7 +155,7 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -158,22 +175,22 @@ more concrete:
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
@@ -206,7 +223,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -221,26 +238,27 @@ This operation now handles the following:
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
of these functions to allocate memory as needed
Implementing the CPU Back-end
Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
@@ -278,19 +296,19 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
}
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -303,26 +321,28 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
"Axpby is only supported for floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
floats. We can only direct to it for ``float32`` types
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call the :meth:`catlas_saxpby` from accelerate.
.. code-block:: C++
@@ -336,7 +356,17 @@ It allocates data for the output, copies ``y`` into it, and then calls the
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -359,20 +389,18 @@ It allocates data for the output, copies ``y`` into it, and then calls the
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
Great! But what about the inputs that do not fit the criteria for accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -382,33 +410,35 @@ For inputs that do not fit the criteria for accelerate, we fall back to
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
Implementing the GPU Back-end
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.
.. note::
Here are some helpful resources if you are new to Metal:
Here are some helpful resources if you are new to metal!
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
.. code-block:: C++
@@ -427,14 +457,15 @@ element in the output.
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
each instantiation a unique host name so we can identify the right kernel for
each data type.
.. code-block:: C++
@@ -457,21 +488,29 @@ each instantiation a unique host name so we can identify it.
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/ shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -479,10 +518,10 @@ below.
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
@@ -513,7 +552,7 @@ below.
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
@@ -536,25 +575,28 @@ below.
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
.. code-block:: C++
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -569,12 +611,12 @@ one we just defined:
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
@@ -583,35 +625,34 @@ one we just defined:
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
throw std::runtime_error("Axpby has no vmap implementation.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
Let's look at the overall directory structure first.
| extensions
| ├── axpby
@@ -625,39 +666,40 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the Python package
the python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
.. code-block:: C++
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -669,17 +711,17 @@ already provided, adding our :meth:`axpby` is simple.
Returns:
array: ``alpha * x + beta * y``
)");
)pbdoc");
}
Most of the complexity in the above example comes from additional bells and
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
@@ -687,8 +729,8 @@ whistles such as the literal names and doc-strings.
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
Building the C++ extension library itself is simple, it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
.. code-block:: cmake
@@ -710,12 +752,12 @@ CONFIG)`` and then link it to your library.
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice:
Here is what that looks like in practice!
.. code-block:: cmake
@@ -737,29 +779,27 @@ Here is what that looks like in practice:
endif()
Finally, we build the nanobind_ bindings
Finally, we build the Pybind11_ bindings
.. code-block:: cmake
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:
build utilities defined in :mod:`mlx.extension` for a simple build process.
.. code-block:: python
.. code-block:: python
from mlx import extension
from setuptools import setup
@@ -769,50 +809,48 @@ build utilities defined in :mod:`mlx.extension`:
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.8",
python_requires=">=3.7",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This results in the directory structure:
This will result in a directory structure as follows:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
Let's look at a simple script and its results:
Let's looks at a simple script and it's results!
.. code-block:: python
@@ -836,12 +874,12 @@ Output:
c correctness: True
Results
^^^^^^^
^^^^^^^^^^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.
.. code-block:: python
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
@@ -860,7 +898,7 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
alpha = 4.0
beta = 2.0
mx.eval(x, y)
mx.eval((x, y))
def bench(f):
# Warm up
@@ -881,23 +919,30 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!
Results:
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
The full example code is available in `mlx-examples <code>`_.
.. code: `TODO_LINK/extensions`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@@ -1,68 +0,0 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -40,8 +40,6 @@ are the CPU and GPU.
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/compile
usage/numpy
usage/using_streams
@@ -58,15 +56,12 @@ are the CPU and GPU.
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
python/nn
python/optimizers
python/tree_utils
@@ -82,4 +77,3 @@ are the CPU and GPU.
:maxdepth: 1
dev/extensions
dev/metal_debugger

View File

@@ -15,10 +15,10 @@ To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- macOS >= 13.5
- macOS >= 13.3
.. note::
MLX is only available on devices running macOS >= 13.5
MLX is only available on devices running macOS >= 13.3
It is highly recommended to use macOS 14 (Sonoma)
@@ -54,7 +54,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@@ -70,13 +70,16 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
Then simply build and install MLX using pip:
Then simply build and install it using pip:
.. code-block:: shell
@@ -120,7 +123,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
cmake .. && make -j
Run tests with:
@@ -139,7 +142,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
@@ -155,21 +158,19 @@ should point to the path to the built metal library.
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
@@ -201,7 +202,7 @@ Then set the active developer directory:
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
x86 Shell
~~~~~~~~~
.. _build shell:

View File

@@ -10,38 +10,27 @@ Array
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.cos
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.dtype
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
@@ -51,8 +40,6 @@ Array
array.split
array.sqrt
array.square
array.squeeze
array.swapaxes
array.sum
array.transpose
array.T

View File

@@ -1,5 +1,7 @@
.. _data_types:
:orphan:
Data Types
==========
@@ -42,27 +44,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``int64``
- 8
- 64-bit signed integer
* - ``bfloat16``
- 2
- 16-bit brain float (e8, m7)
* - ``float16``
- 2
- 16-bit IEEE float (e5, m10)
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
* - ``complex64``
- 8
- 64-bit complex float
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype

View File

@@ -9,11 +9,9 @@ Devices and Streams
:toctree: _autosummary
Device
Stream
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream
stream
synchronize

View File

@@ -1,14 +0,0 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention

View File

@@ -9,4 +9,3 @@ Linear Algebra
:toctree: _autosummary
norm
qr

View File

@@ -1,17 +0,0 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
get_active_memory
get_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
clear_cache
start_capture
stop_capture

View File

@@ -173,7 +173,6 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
.. toctree::
@@ -181,4 +180,3 @@ In detail:
nn/layers
nn/functions
nn/losses
nn/init

View File

@@ -12,24 +12,12 @@ simple functions.
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
gelu
gelu_approx
gelu_fast_approx
glu
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu6
selu
sigmoid
prelu
silu
softmax
softplus
softshrink
step
tanh
selu
mish

View File

@@ -1,45 +0,0 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@@ -9,39 +9,29 @@ Layers
:toctree: _autosummary
:template: nn-module-template.rst
ALiBi
AvgPool1d
AvgPool2d
BatchNorm
Sequential
ReLU
PReLU
GELU
SiLU
Step
SELU
Mish
Embedding
Linear
QuantizedLinear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm
InstanceNorm
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GroupNorm
GRU
InstanceNorm
LayerNorm
Linear
LSTM
MaxPool1d
MaxPool2d
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softshrink
Step
Transformer
Upsample
MultiHeadAttention
ALiBi
RoPE
SinusoidalPositionalEncoding

View File

@@ -10,16 +10,14 @@ Loss Functions
:template: nn-module-template.rst
binary_cross_entropy
cosine_similarity_loss
cross_entropy
gaussian_nll_loss
hinge_loss
huber_loss
kl_div_loss
l1_loss
log_cosh_loss
margin_ranking_loss
mse_loss
nll_loss
smooth_l1_loss
triplet_loss
triplet_loss
hinge_loss
huber_loss
log_cosh_loss
cosine_similarity_loss

View File

@@ -11,7 +11,6 @@ Module
:toctree: _autosummary
Module.training
Module.state
.. rubric:: Methods
@@ -30,7 +29,6 @@ Module
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze

View File

@@ -5,13 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
all
allclose
allclose
any
arange
arccos
@@ -25,13 +25,6 @@ Operations
argpartition
argsort
array_equal
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
broadcast_to
ceil
clip
@@ -39,24 +32,15 @@ Operations
convolve
conv1d
conv2d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
divide
divmod
equal
erf
erfinv
exp
expm1
expand_dims
eye
flatten
@@ -67,12 +51,7 @@ Operations
greater_equal
identity
inner
isclose
isinf
isnan
isneginf
isposinf
left_shift
less
less_equal
linspace
@@ -90,13 +69,11 @@ Operations
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
negative
not_equal
ones
ones_like
outer
@@ -105,11 +82,9 @@ Operations
prod
quantize
quantized_matmul
radians
reciprocal
repeat
reshape
right_shift
round
rsqrt
save
@@ -128,7 +103,6 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum
@@ -138,8 +112,6 @@ Operations
tan
tanh
tensordot
tile
topk
transpose
tri
tril

View File

@@ -29,8 +29,19 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
.. toctree::
.. currentmodule:: mlx.optimizers
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
OptimizerState
Optimizer
SGD
RMSprop
Adagrad
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -1,20 +0,0 @@
.. _common_optimizers:
Common Optimizers
=================
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
:template: optimizers-template.rst
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW
Adamax
Lion

View File

@@ -1,23 +0,0 @@
Optimizer
=========
.. currentmodule:: mlx.optimizers
.. autoclass:: Optimizer
.. rubric:: Attributes
.. autosummary::
:toctree: _autosummary
Optimizer.state
.. rubric:: Methods
.. autosummary::
:toctree: _autosummary
Optimizer.apply_gradients
Optimizer.init
Optimizer.update

View File

@@ -1,15 +0,0 @@
.. _schedulers:
Schedulers
==========
.. currentmodule:: mlx.optimizers
.. autosummary::
:toctree: _autosummary
cosine_decay
exponential_decay
join_schedules
linear_schedule
step_decay

View File

@@ -33,14 +33,13 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
.. autosummary::
:toctree: _autosummary
seed
key
split
bernoulli
categorical
gumbel
key
normal
multivariate_normal
randint
seed
split
truncated_normal
uniform
truncated_normal

View File

@@ -9,11 +9,9 @@ Transforms
:toctree: _autosummary
eval
compile
disable_compile
enable_compile
grad
value_and_grad
jvp
vjp
vmap
simplify

View File

@@ -19,4 +19,3 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path

View File

@@ -1,430 +0,0 @@
.. _compile:
Compilation
===========
.. currentmodule:: mlx.core
MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.
Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.
Basics of Compile
-----------------
Let's start with a simple example:
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
# Regular call, no compilation
# Prints: array(2.36788, dtype=float32)
print(fun(x, y))
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.
.. code-block:: python
def fun(x, y):
return mx.exp(-x) + y
x = mx.array(1.0)
y = mx.array(2.0)
compiled_fun = mx.compile(fun)
# Compiled here
compiled_fun(x, y)
# Not compiled again
compiled_fun(x, y)
# Not compiled again
mx.compile(fun)(x, y)
There are some important cases to be aware of that can cause a function to
be recompiled:
* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function
In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.
Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:
.. code-block:: python
a = mx.array(1.0)
# Don't do this, compiles lambda at each iteration
for _ in range(5):
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
Example Speedup
---------------
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:
.. code-block:: python
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speedup both cases considerably.
Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:
.. code-block:: python
import time
def timeit(fun, x):
# warm up
for _ in range(10):
mx.eval(fun(x))
tic = time.perf_counter()
for _ in range(100):
mx.eval(fun(x))
toc = time.perf_counter()
tpi = 1e3 * (toc - tic) / 100
print(f"Time per iteration {tpi:.3f} (ms)")
Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Crash
return mx.exp(z)
fun(mx.array(5.0))
For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:
.. code-block:: python
@mx.compile
def fun(x):
z = -x
print(z) # Okay
return mx.exp(z)
mx.disable_compile()
fun(mx.array(5.0))
Pure Functions
--------------
Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z)
fun(mx.array(1.0), mx.array(2.0))
# Crash!
print(state)
After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.
You have two options to deal with this. The first option is to simply return
``state`` as an output:
.. code-block:: python
state = []
@mx.compile
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
_, state = fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:
.. code-block:: python
from functools import partial
state = []
# Tell compile to capture state as an output
@partial(mx.compile, outputs=state)
def fun(x, y):
z = x + y
state.append(z)
return mx.exp(z), state
fun(mx.array(1.0), mx.array(2.0))
# Prints [array(3, dtype=float32)]
print(state)
This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.
Compiled functions will also treat any inputs not in the parameter list as
constants. For example:
.. code-block:: python
state = [mx.array(1.0)]
@mx.compile
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Still prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:
.. code-block:: python
from functools import partial
state = [mx.array(1.0)]
# Tell compile to capture state as an input
@partial(mx.compile, inputs=state)
def fun(x):
return x + state[0]
# Prints array(2, dtype=float32)
print(fun(mx.array(1.0)))
# Update state
state[0] = mx.array(5.0)
# Prints array(6, dtype=float32)
print(fun(mx.array(1.0)))
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
of a common setup: training a model with :obj:`mlx.nn.Module` using an
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Perform 10 steps of gradient descent
for it in range(10):
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
# 4 examples with 10 features each
x = mx.random.uniform(shape=(4, 10))
# 0, 1 targets
y = mx.array([0, 1, 0, 1])
# Simple linear model
model = nn.Linear(10, 1)
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
def loss_fn(model, x, y):
logits = model(x).squeeze()
return nn.losses.binary_cross_entropy(logits, y)
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
# Perform 10 steps of gradient descent
for it in range(10):
loss = step(x, y)
# Evaluate the model and optimizer state
mx.eval(state)
print(loss)
.. note::
If you are using a module which performs random sampling such as
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
optimizer.state, mx.random.state]``.
.. note::
For more examples of compiling full training graphs checkout the `MLX
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
Transformations with Compile
----------------------------
In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.
Compiling transformed functions works just as expected:
.. code-block:: python
grad_fn = mx.grad(mx.exp)
compiled_grad_fn = mx.compile(grad_fn)
# Prints: array(2.71828, dtype=float32)
print(grad_fn(mx.array(1.0)))
# Also prints: array(2.71828, dtype=float32)
print(compiled_grad_fn(mx.array(1.0)))
.. note::
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
the most opportunity to optimize the computation graph:
.. code-block:: python
@mx.compile
def inner(x):
return mx.exp(-mx.abs(x))
def outer(x):
inner(inner(x))
# Compiling the outer function is good to do as it will likely
# be faster even though the inner functions are compiled
fun = mx.compile(outer)

View File

@@ -1,191 +0,0 @@
.. _function_transforms:
Function Transforms
===================
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations check-out the :ref:`API documentation <transforms>`.
The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.
Here is a simple example:
.. code-block:: shell
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
.. code-block:: shell
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
Automatic Differentiation
-------------------------
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
.. note::
If you are coming to MLX from PyTorch, you no longer need functions like
``backward``, ``zero_grad``, and ``detach``, or properties like
``requires_grad``.
The most basic example is taking the gradient of a scalar-valued function as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` function to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:
.. code-block:: python
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:
.. code-block:: python
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).
Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:
.. code-block:: python
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
Automatic Vectorization
-----------------------
.. _vmap:
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.
.. warning::
Some operations are not yet supported with :func:`vmap`. If you encounter an error
like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
<https://github.com/ml-explore/mlx/issues>`_ and include your function.
We will prioritize including it.
A naive way to add the elements from two sets of vectors is with a loop:
.. code-block:: python
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
Let's time these two different versions:
.. code-block:: python
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -18,9 +18,9 @@ describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation lets us record a compute graph without actually doing any
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
:func:`vmap` and graph optimizations like :func:`simplify`.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to

View File

@@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy")
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:

View File

@@ -8,4 +8,3 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)

View File

@@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
metal::stop_capture();
}

View File

@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.27)
cmake_minimum_required(VERSION 3.24)
project(_ext LANGUAGES CXX)
project(mlx_sample_extensions LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
@@ -11,12 +11,8 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
@@ -42,6 +38,7 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
@@ -57,15 +54,13 @@ if(MLX_BUILD_METAL)
endif()
# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
# ----------------------------- Pybind -----------------------------
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
endif()
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()

View File

@@ -1,18 +0,0 @@
## Build the extensions
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
@@ -43,7 +43,7 @@ array axpby(
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -61,7 +61,7 @@ array axpby(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -104,14 +104,11 @@ void axpby_impl(
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -150,7 +147,11 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -174,13 +175,10 @@ void axpby_impl_accelerate(
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -191,16 +189,14 @@ void Axpby::eval_cpu(
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
eval(inputs, out);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
#endif
@@ -212,14 +208,11 @@ void Axpby::eval_cpu(
#ifdef _METAL_
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -302,9 +295,7 @@ void Axpby::eval_gpu(
#else // Metal is not available
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -315,7 +306,7 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -330,33 +321,32 @@ std::vector<array> Axpby::jvp(
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
@@ -368,4 +358,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -42,13 +42,11 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -56,9 +54,8 @@ class Axpby : public Primitive {
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const array& cotan,
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
@@ -66,7 +63,7 @@ class Axpby : public Primitive {
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -83,7 +80,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(const std::vector<array>& inputs, array& out);
};
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -1,31 +1,31 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "axpby/axpby.h"
namespace nb = nanobind;
using namespace nb::literals;
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ NB_MODULE(_ext, m) {
Returns:
array: ``alpha * x + beta * y``
)");
}
)pbdoc");
}

View File

@@ -1,8 +0,0 @@
[build-system]
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.9.0",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +0,0 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
from setuptools import setup
@@ -9,11 +9,11 @@ if __name__ == "__main__":
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False,
python_requires=">=3.8",
)

View File

@@ -41,6 +41,6 @@ error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
throughput = num_iters / (toc - tic)
print(
f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
f"Throughput {throughput:.5f} (it/s)"
)

View File

@@ -3,10 +3,8 @@ target_sources(
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
@@ -21,7 +19,7 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
target_sources(

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <functional>
@@ -12,6 +12,16 @@ namespace mlx::core {
namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
@@ -27,27 +37,26 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
}
array::array(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
shape,
dtype,
std::move(primitive),
std::move(inputs))) {}
inputs)) {}
std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -62,30 +71,17 @@ array::array(std::initializer_list<float> data)
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
dtype)) {
init(data.begin());
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
set_data(data, deleter);
}
void array::detach() {
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
@@ -93,13 +89,7 @@ void array::detach() {
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::scheduled) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
mlx::core::eval({*this});
}
mlx::core::eval({*this});
}
bool array::is_tracer() const {
@@ -148,89 +138,23 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
}
array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
: shape(std::move(shape)),
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
}
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destory their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
}
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
: arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
inputs(inputs) {
std::tie(size, strides) = cum_prod(shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
@@ -9,7 +8,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
namespace mlx::core {
@@ -33,7 +31,7 @@ class array {
template <typename It>
array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -43,19 +41,16 @@ class array {
/* Special case so empty lists default to float32. */
array(std::initializer_list<float> data);
/* Special case so array({}, type) is an empty array. */
array(std::initializer_list<int> data, Dtype dtype);
template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter = allocator::free);
@@ -114,15 +109,6 @@ class array {
return array_desc_->strides;
};
/**
* Get the stride of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
};
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
@@ -135,16 +121,17 @@ class array {
template <typename T>
T item();
template <typename T>
T item() const;
struct ArrayIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = size_t;
using value_type = const array;
using reference = value_type;
explicit ArrayIterator(const array& arr, int idx = 0);
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
reference operator*() const;
@@ -184,15 +171,15 @@ class array {
*/
array(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
const std::vector<array>& inputs);
static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@@ -232,11 +219,6 @@ class array {
return *(array_desc_->primitive);
};
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
@@ -251,11 +233,6 @@ class array {
return array_desc_->inputs;
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
@@ -266,17 +243,6 @@ class array {
array_desc_->position = position;
}
/** The i-th output of the array's primitive. */
const array& output(int i) const {
if (i == array_desc_->position) {
return *this;
} else if (i < array_desc_->position) {
return siblings()[i];
} else {
return siblings()[i + 1];
}
};
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
@@ -309,12 +275,6 @@ class array {
return array_desc_->data->buffer;
};
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
@@ -325,27 +285,9 @@ class array {
return static_cast<T*>(array_desc_->data_ptr);
};
enum Status { unscheduled, scheduled, available };
bool is_available() const {
return status() == Status::available;
}
const Status status() const {
return array_desc_->status;
}
void set_status(Status s) const {
array_desc_->status = s;
}
// Get the array's shared event
Event& event() const {
return array_desc_->event;
}
// Attach an event to a not yet evaluated array
void attach_event(Event e) const {
array_desc_->event = std::move(e);
// Check if the array has been evaluated
bool is_evaled() const {
return array_desc_->data != nullptr;
}
// Mark the array as a tracer array (true) or not.
@@ -373,15 +315,6 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
@@ -396,12 +329,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
Status status;
// An event on the array used for synchronization
Event event;
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -409,7 +337,7 @@ class array {
// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data;
std::shared_ptr<Data> data{nullptr};
// Properly offset data pointer
void* data_ptr{nullptr};
@@ -429,26 +357,20 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
~ArrayDesc();
private:
// Initialize size, strides, and other metadata
void init();
const std::vector<array>& inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_;
// and a the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
template <typename T>
@@ -460,9 +382,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
init(data);
}
@@ -479,9 +401,9 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
@@ -498,19 +420,6 @@ T array::item() {
return *data<T>();
}
template <typename T>
T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (status() == Status::unscheduled) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
const_cast<array*>(this)->eval();
return *data<T>();
}
template <typename It>
void array::init(It src) {
set_data(allocator::malloc(size() * size_of(dtype())));
@@ -557,15 +466,4 @@ void array::init(It src) {
}
}
/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
@@ -29,16 +29,12 @@ std::tuple<bool, size_t, array> check_transpose(const array& arr) {
}
}
inline void matmul_cblas_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -46,14 +42,6 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -62,34 +50,21 @@ inline void matmul_cblas_general(
M,
N,
K,
alpha, // alpha
1.0f, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
0.0f, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_cblas_general(a_pre, b_pre, out);
}
inline void matmul_bnns_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -97,19 +72,11 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
/* float alpha = */ 1.0,
/* float beta = */ 0.0,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
@@ -190,46 +157,6 @@ inline void matmul_bnns_general(
BNNSFilterDestroy(bnns_filter);
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_bnns_general(a_pre, b_pre, out);
}
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -239,16 +166,4 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
if (out.dtype() == float32) {
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
}
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
@@ -31,15 +31,10 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@@ -55,40 +50,45 @@ DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT_MULTI(DivMod)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
eval(inputs, out);
unary(in, out, AbsOp());
}
}
@@ -136,8 +136,12 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -148,8 +152,12 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -160,8 +168,12 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -172,8 +184,12 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -184,8 +200,12 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -196,8 +216,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -209,23 +233,30 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
@@ -237,8 +268,12 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -249,8 +284,12 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -295,14 +334,57 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
// TODO: Avoid code duplication with the common backend.
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
a,
b,
out,
RemainderFn{},
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* o, auto n) {
int num_el = n;
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
});
} else {
binary(a, b, out, RemainderFn{});
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
@@ -311,19 +393,6 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -340,8 +409,12 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) {
case Base::e:
vvlogf(
@@ -365,11 +438,15 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
@@ -378,6 +455,47 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -407,8 +525,13 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return -x; });
}
@@ -421,13 +544,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
@@ -469,8 +586,12 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -481,8 +602,12 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -493,8 +618,12 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
@@ -505,8 +634,12 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
@@ -561,8 +694,12 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -573,8 +710,12 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);

View File

@@ -24,6 +24,8 @@ void _qmm_t_4_64(
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;

View File

@@ -10,65 +10,78 @@
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
};
}
} // namespace
// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
VT _sum = {0};
while (size >= N) {
_sum += (*(VT*)x);
x += N;
size -= N;
}
T sum = _sum[0];
for (int i = 1; i < N; i++) {
sum += _sum[i];
}
*accum += sum;
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -81,11 +94,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_sum<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
@@ -99,11 +111,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_max<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
@@ -117,11 +128,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_min<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -218,21 +218,13 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
T maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
maximum = std::max(maximum, *current_in_ptr);
current_in_ptr++;
}
@@ -242,29 +234,18 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
T normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
@@ -273,33 +254,14 @@ void softmax(const array& in, array& out) {
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
*current_out_ptr *= normalizer;
current_out_ptr++;
}
}
@@ -312,12 +274,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
@@ -346,29 +303,15 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
}
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
break;
case bfloat16:
eval(inputs, out);

View File

@@ -1,73 +1,19 @@
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${COMPILER}
${PROJECT_SOURCE_DIR}
${CLANG}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
add_custom_target(
cpu_compiled_preamble
DEPENDS compiled_preamble.cpp
)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
if (IOS)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
endif()

View File

@@ -7,7 +7,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -74,7 +73,7 @@ void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void DivMod::eval(
@@ -136,59 +135,93 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
struct RemainderFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return std::fmod(numerator, denominator);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
};
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
binary(a, b, out, RemainderFn{});
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
return x == y || (std::isnan(x) && std::isnan(y));
});
} else {
comparison_op(inputs[0], inputs[1], out, detail::Equal());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
}
}
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
}
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
}
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
}
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
}
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
auto op = [](auto x, auto y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = (x > y) ? x : y;
auto minval = (x > y) ? y : x;
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(std::exp(minval - maxval)));
};
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, op);
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, op);
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, op);
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
@@ -200,97 +233,65 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
}
struct PowerFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
if (exp < 0) {
throw std::invalid_argument(
"Integers cannot be raise to negative powers");
}
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
binary(a, b, out, PowerFn{});
}
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
@@ -9,7 +10,7 @@ namespace mlx::core {
namespace {
enum class BinaryOpType {
enum BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
@@ -20,17 +21,17 @@ enum class BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = BinaryOpType::ScalarScalar;
bopt = ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
} else {
bopt = BinaryOpType::General;
bopt = General;
}
return bopt;
}
@@ -39,83 +40,29 @@ void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
BinaryOpType bopt) {
switch (bopt) {
case BinaryOpType::ScalarScalar:
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
case ScalarVector:
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
break;
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
case VectorScalar:
case VectorVector:
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
break;
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
case General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
@@ -424,25 +371,25 @@ void binary_op(
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
if (bopt == ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
if (bopt == ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
if (bopt == VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
if (bopt == VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
@@ -475,17 +422,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
dim = d;
}
@@ -495,20 +442,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:

View File

@@ -260,14 +260,14 @@ void binary_op(
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
if (bopt == ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
if (bopt == ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
if (bopt == VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
if (bopt == VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
bopt = ScalarVector;
dim = d;
}
@@ -347,20 +347,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = BinaryOpType::General;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case BinaryOpType::VectorVector:
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case BinaryOpType::VectorScalar:
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case BinaryOpType::ScalarVector:
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:

View File

@@ -1,227 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void print_constant(std::ostream& os, const array& x) {
switch (x.dtype()) {
case float32:
return print_float_constant<float>(os, x);
case float16:
return print_float_constant<float16_t>(os, x);
case bfloat16:
return print_float_constant<bfloat16_t>(os, x);
case complex64:
return print_complex_constant<complex64_t>(os, x);
case int8:
return print_int_constant<int8_t>(os, x);
case int16:
return print_int_constant<int16_t>(os, x);
case int32:
return print_int_constant<int32_t>(os, x);
case int64:
return print_int_constant<int64_t>(os, x);
case uint8:
return print_int_constant<uint8_t>(os, x);
case uint16:
return print_int_constant<uint16_t>(os, x);
case uint32:
return print_int_constant<uint32_t>(os, x);
case uint64:
return print_int_constant<uint64_t>(os, x);
case bool_:
os << std::boolalpha << x.item<bool>();
return;
default:
throw std::runtime_error("Unsupported constant type");
}
}
std::string get_type_string(Dtype d) {
switch (d) {
case float32:
return "float";
case float16:
return "float16_t";
case bfloat16:
return "bfloat16_t";
case complex64:
return "complex64_t";
case bool_:
return "bool";
case int8:
return "int8_t";
case int16:
return "int16_t";
case int32:
return "int32_t";
case int64:
return "int64_t";
case uint8:
return "uint8_t";
case uint16:
return "uint16_t";
case uint32:
return "uint32_t";
case uint64:
return "uint64_t";
default: {
std::ostringstream msg;
msg << "Unsupported compilation type " << d;
throw std::runtime_error(msg.str());
}
}
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (const auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
} else if (non_scalar_inputs == 0 && !shape.empty()) {
contiguous = false;
}
return contiguous;
}
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
// - Not a scalar
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
}
} // namespace mlx::core

View File

@@ -1,70 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/primitives.h"
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d);
template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< x.item<T>() << std::setprecision(old_precision);
}
template <typename T>
void print_int_constant(std::ostream& os, const array& x) {
os << x.item<T>();
}
template <typename T>
void print_complex_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision();
T constant = x.item<T>();
os << get_type_string(x.dtype()) << "("
<< std::setprecision(std::numeric_limits<float>::digits10 + 1)
<< constant.real() << ", " << constant.imag() << ")"
<< std::setprecision(old_precision);
}
void print_constant(std::ostream& os, const array& x);
inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
} // namespace mlx::core

View File

@@ -1,356 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <filesystem>
#include <list>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
constexpr int max_file_name_length = 245;
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
lib_exists = f.good();
}
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
<< " with error code " << return_code << "." << std::endl;
throw std::runtime_error(msg.str());
}
}
// load library
libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
<< kernel_name << std::endl
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
return fun;
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
if (contiguous) {
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
} else {
for (int d = 0; d < ndim; ++d) {
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
<< "]; ++i" << d << ") {" << std::endl;
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[i];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
<< xname << ";" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
if (contiguous) {
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
} else {
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
}
// Close loops
if (contiguous) {
os << " }" << std::endl;
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
os << " " << xname << " += " << xname << "_strides[" << d << "];"
<< std::endl;
if (d < ndim - 1) {
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
<< " * shape[" << d + 1 << "];" << std::endl;
}
}
os << " }" << std::endl;
}
}
// Finish the kernel
os << "}" << std::endl;
}
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
kernel_name,
inputs_,
outputs_,
tape_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
for (auto& x : outputs) {
args.push_back(x.data<void>());
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
}
} // namespace mlx::core

View File

@@ -1,23 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is not available so check if the device is a GPU
// device.
namespace detail {
bool compile_available_for_device(const Device& device) {
return device == Device::gpu;
}
} // namespace detail
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform.");
}
} // namespace mlx::core

View File

@@ -1,11 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
#include "mlx/backend/common/ops.h"
// clang-format on
const char* get_kernel_preamble();

View File

@@ -1,10 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#include <vecLib/cblas_new.h>
#else
#include <cblas.h>
#endif
@@ -28,25 +27,19 @@ void slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
const T* start_wt_ptr = wt.data<T>();
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int iH = in.shape(1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2];
@@ -61,36 +54,32 @@ void slow_conv_1D(
for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) {
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int o = 0; o < O; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
} // c
} // ih check
} // wh
} // ih check
} // wh
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // g
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // oh
in_ptr += in_stride_N;
out_ptr += out_stride_N;
} // n
}
@@ -101,16 +90,14 @@ void slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
@@ -133,8 +120,6 @@ void slow_conv_2D(
const size_t out_stride_W = out.strides()[2];
const size_t out_stride_O = out.strides()[3];
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
@@ -146,10 +131,8 @@ void slow_conv_2D(
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
@@ -170,74 +153,25 @@ void slow_conv_2D(
} // o
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
@@ -257,17 +191,13 @@ void slow_conv_2D(
};
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@@ -316,18 +246,15 @@ void dispatch_slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
if (in.dtype() == float32) {
return slow_conv_1D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == float16) {
return slow_conv_1D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == bfloat16) {
return slow_conv_1D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@@ -340,18 +267,15 @@ void dispatch_slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const std::vector<int>& wt_dilation) {
if (in.dtype() == float32) {
return slow_conv_2D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == float16) {
return slow_conv_2D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else if (in.dtype() == bfloat16) {
return slow_conv_2D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
in, wt, out, padding, wt_strides, wt_dilation);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@@ -371,15 +295,11 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
auto conv_dtype = float32;
// Pad input
@@ -411,11 +331,6 @@ void explicit_gemm_conv_1D_cpu(
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
std::swap(strided_shape[2], strided_shape[3]);
std::swap(strided_strides[2], strided_strides[3]);
}
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
@@ -430,19 +345,7 @@ void explicit_gemm_conv_1D_cpu(
auto gemm_wt = wt;
auto gemm_out = out;
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
array wt_transpose(
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(
wt,
{wt.strides(0), wt.strides(2), wt.strides(1)},
wt.flags(),
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -454,29 +357,27 @@ void explicit_gemm_conv_1D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
for (int g = 0; g < groups; ++g) {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O_per_group, // N
C_per_group * wH, // K
1.0f, // alpha
in_strided.data<float>() + g * C_per_group * wH, // A
wH * C, // lda
gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
wH * C_per_group, // ldb
0.0f, // beta
gemm_out.data<float>() + g * O_per_group, // C
O // ldc
);
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
@@ -592,16 +493,13 @@ void conv_1D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
const std::vector<int>& wt_dilation) {
if (wt_dilation[0] == 1) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation);
}
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
}
void conv_2D_cpu(
@@ -610,11 +508,8 @@ void conv_2D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
const std::vector<int>& wt_dilation) {
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
}
} // namespace
@@ -628,26 +523,12 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
// 2D convolution
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// 1D convolution
else if (in.ndim() == (1 + 2)) {
return conv_1D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// Throw error
else {

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <numeric>
@@ -25,196 +25,121 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim1(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
template <typename SrcT, typename DstT>
void copy_general_dim1(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[0];
src_idx += src.strides()[0];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general_dim2(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
src_idx += src.strides()[1];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general_dim3(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
src_idx += src.strides()[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general_dim4(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
src_idx += src.strides()[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general(const array& src, array& dst) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim1<SrcT, DstT>(src, dst);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim2<SrcT, DstT>(src, dst);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim3<SrcT, DstT>(src, dst);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim4<SrcT, DstT>(src, dst);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
template <typename SrcT, typename DstT, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
size_t offset_src,
size_t offset_dst) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
copy_general_general_dims<SrcT, DstT, D - 1>(
src, dst, offset_src, offset_dst);
offset_src += stride_src;
offset_dst += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
@@ -223,56 +148,37 @@ inline void copy_general_general_dims(
}
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
template <typename SrcT, typename DstT>
void copy_general_general(const array& src, array& dst) {
switch (src.ndim()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
return;
}
int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
}
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
void copy(const array& src, array& dst, CopyType ctype) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
@@ -281,103 +187,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general<SrcT, DstT>(src, dst);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general_general<SrcT, DstT>(src, dst);
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
template <typename SrcT>
void copy(const array& src, array& dst, CopyType ctype) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bool>(src, dst, ctype);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint8_t>(src, dst, ctype);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint16_t>(src, dst, ctype);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint32_t>(src, dst, ctype);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint64_t>(src, dst, ctype);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int8_t>(src, dst, ctype);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int16_t>(src, dst, ctype);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int32_t>(src, dst, ctype);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int64_t>(src, dst, ctype);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float16_t>(src, dst, ctype);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float>(src, dst, ctype);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
template <typename... Args>
inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, complex64_t>(src, dst, ctype);
break;
}
}
@@ -385,23 +242,58 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
return copy_inplace_dispatch(src, dst, ctype);
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype);
break;
case uint8:
copy<uint8_t>(src, dst, ctype);
break;
case uint16:
copy<uint16_t>(src, dst, ctype);
break;
case uint32:
copy<uint32_t>(src, dst, ctype);
break;
case uint64:
copy<uint64_t>(src, dst, ctype);
break;
case int8:
copy<int8_t>(src, dst, ctype);
break;
case int16:
copy<int16_t>(src, dst, ctype);
break;
case int32:
copy<int32_t>(src, dst, ctype);
break;
case int64:
copy<int64_t>(src, dst, ctype);
break;
case float16:
copy<float16_t>(src, dst, ctype);
break;
case float32:
copy<float>(src, dst, ctype);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<complex64_t>(src, dst, ctype);
break;
}
}
void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
dst.set_data(
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
src.data_size(),
src.strides(),
src.flags());
break;
case CopyType::Scalar:
case CopyType::General:
@@ -415,62 +307,4 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
template <>
void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -26,15 +26,4 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -1,13 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#include <vecLib/cblas_new.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
@@ -41,24 +39,18 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
@@ -82,7 +74,6 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
@@ -90,34 +81,32 @@ DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT_MULTI(DivMod)
namespace {
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
inline void matmul_common_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
@@ -135,17 +124,9 @@ inline void matmul_common_general(
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -154,41 +135,16 @@ inline void matmul_common_general(
M,
N,
K,
alpha, // alpha
1.0f, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
0.0f, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_common_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

11
mlx/backend/common/erf.h Normal file
View File

@@ -0,0 +1,11 @@
// Copyright © 2023 Apple Inc.
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
float erfinv(float a);
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More