mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:33:05 +08:00
Compare commits
82 Commits
dynamic_re
...
cpp20
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4515866024 | ||
![]() |
6fe2b82926 | ||
![]() |
c75b5e9d19 | ||
![]() |
6f12eda549 | ||
![]() |
a541fe9312 | ||
![]() |
2bdd20f257 | ||
![]() |
aa7b9688ce | ||
![]() |
0a41393dba | ||
![]() |
e300a01f4a | ||
![]() |
f288db8d34 | ||
![]() |
33421c1dd3 | ||
![]() |
5cc5201914 | ||
![]() |
252e423e81 | ||
![]() |
a4a2764a52 | ||
![]() |
ab8e832c18 | ||
![]() |
1ce0c0fcb0 | ||
![]() |
657f466402 | ||
![]() |
c7b0300af5 | ||
![]() |
da8c885784 | ||
![]() |
1ccaf80575 | ||
![]() |
ec36bfa317 | ||
![]() |
b8f76f717a | ||
![]() |
d1766f2c70 | ||
![]() |
516ded618b | ||
![]() |
c9c81d0584 | ||
![]() |
545f84d905 | ||
![]() |
d5ec172c95 | ||
![]() |
25b3a3e541 | ||
![]() |
058d6ce683 | ||
![]() |
eab93985b8 | ||
![]() |
b51d70a83c | ||
![]() |
259025100e | ||
![]() |
c9d30aa6ac | ||
![]() |
8544b42007 | ||
![]() |
6fa0501387 | ||
![]() |
ae69cb15e9 | ||
![]() |
a64a8dfe45 | ||
![]() |
491fa95b1f | ||
![]() |
92ec632ad5 | ||
![]() |
8ecdfb718b | ||
![]() |
4ba0c24a8f | ||
![]() |
935c8c4bb1 | ||
![]() |
88f993da38 | ||
![]() |
ebfe64b92d | ||
![]() |
0308e9af71 | ||
![]() |
c3628eea49 | ||
![]() |
e03f0372b1 | ||
![]() |
f17536af9c | ||
![]() |
ed4ec81bca | ||
![]() |
7480059306 | ||
![]() |
8bae22b0fa | ||
![]() |
49c34c4161 | ||
![]() |
5548fcc96d | ||
![]() |
070bd433ab | ||
![]() |
c8fb54951a | ||
![]() |
f110357aaa | ||
![]() |
a6b426422e | ||
![]() |
d03c01dfbc | ||
![]() |
a82996e9fb | ||
![]() |
af5a614aad | ||
![]() |
f9640e049d | ||
![]() |
4768c61b57 | ||
![]() |
dfccd17ab9 | ||
![]() |
635117c5d4 | ||
![]() |
50f3535693 | ||
![]() |
9111999af3 | ||
![]() |
6bd28d246e | ||
![]() |
4d595a2a39 | ||
![]() |
3a21f61772 | ||
![]() |
4e1e9520e1 | ||
![]() |
0bf19037ca | ||
![]() |
f3dfa36a3a | ||
![]() |
4f9b60dd53 | ||
![]() |
f76a49e555 | ||
![]() |
310ad8d9db | ||
![]() |
56db268f47 | ||
![]() |
92ab6bdeb8 | ||
![]() |
0070e360a1 | ||
![]() |
9df8fed046 | ||
![]() |
a59fae040f | ||
![]() |
29a620cab2 | ||
![]() |
87d7a2520e |
@@ -24,7 +24,7 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
xcode: "16.0.0"
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
@@ -70,8 +70,8 @@ jobs:
|
||||
git push -f origin gh-pages
|
||||
|
||||
linux_build_and_test:
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
machine:
|
||||
image: ubuntu-2404:2024.11.1
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
@@ -84,30 +84,33 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install numpy
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y python3.9 python3.9-distutils python3.9-dev
|
||||
python3.9 -m pip install --upgrade cmake
|
||||
python3.9 -m pip install nanobind==2.4.0
|
||||
python3.9 -m pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install libopenblas-dev liblapacke-dev openmpi-bin libopenmpi-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
python3.9 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
python3.9 setup.py develop
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python3.9 -m pip install typing_extensions
|
||||
python3.9 setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
python3.9 -m unittest discover python/tests -v
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
@@ -122,7 +125,10 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.0.0"
|
||||
deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
@@ -137,7 +143,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@@ -146,7 +152,9 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
|
||||
pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
@@ -173,7 +181,11 @@ jobs:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
mkdir -p build
|
||||
cd build/
|
||||
cmake .. \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
|
||||
make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: |
|
||||
@@ -188,14 +200,15 @@ jobs:
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
-DMLX_METAL_JIT=ON \
|
||||
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
|
||||
make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON -DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>" \
|
||||
pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
@@ -208,7 +221,10 @@ jobs:
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.0.0"
|
||||
deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
@@ -226,7 +242,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
@@ -237,6 +253,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
@@ -250,6 +267,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
|
||||
python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
@@ -291,7 +309,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
@@ -330,9 +348,10 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
xcode_version: ["16.0.0"]
|
||||
deployment_target: ["", "13.5"]
|
||||
- linux_build_and_test
|
||||
- build_documentation
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
@@ -350,7 +369,8 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["16.0.0"]
|
||||
deployment_target: ["", "13.5"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
- build_documentation:
|
||||
filters:
|
||||
@@ -374,7 +394,8 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
xcode_version: ["16.0.0"]
|
||||
deployment_target: ["", "13.5"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -387,7 +408,8 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
xcode_version: ["16.0.0"]
|
||||
deployment_target: ["", "13.5"]
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -398,7 +420,8 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
xcode_version: ["16.0.0"]
|
||||
deployment_target: ["", "13.5"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
linux_test_release:
|
||||
when:
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@@ -76,6 +76,9 @@ build/
|
||||
*.out
|
||||
*.app
|
||||
|
||||
# Debug symbols
|
||||
*.pdb
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
.DS_Store
|
||||
|
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||
|
@@ -1,10 +1,10 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
cmake_minimum_required(VERSION 3.25)
|
||||
|
||||
project(mlx LANGUAGES C CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
@@ -20,12 +20,14 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.21.1)
|
||||
set(MLX_VERSION 0.22.0)
|
||||
endif()
|
||||
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
@@ -93,8 +95,7 @@ elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||
|
||||
set(METAL_CPP_URL
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
|
||||
)
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
|
||||
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
@@ -113,16 +114,55 @@ elseif(MLX_BUILD_METAL)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
if(MSVC)
|
||||
# GGUF does not build with MSVC.
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
# There is no prebuilt OpenBLAS distribution for MSVC.
|
||||
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
|
||||
endif()
|
||||
# Windows implementation of dlfcn.h APIs.
|
||||
FetchContent_Declare(
|
||||
dlfcn-win32
|
||||
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
||||
GIT_TAG v1.4.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
block()
|
||||
set(BUILD_SHARED_LIBS OFF)
|
||||
FetchContent_MakeAvailable(dlfcn-win32)
|
||||
endblock()
|
||||
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
|
||||
target_link_libraries(mlx PRIVATE dl)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if(ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
|
||||
# Download and build OpenBLAS from source code.
|
||||
FetchContent_Declare(
|
||||
openblas
|
||||
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
|
||||
GIT_TAG v0.3.28
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(BUILD_STATIC_LIBS ON) # link statically
|
||||
set(NOFORTRAN ON) # msvc has no fortran compiler
|
||||
FetchContent_MakeAvailable(openblas)
|
||||
target_link_libraries(mlx PRIVATE openblas)
|
||||
target_include_directories(
|
||||
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
|
||||
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
|
||||
else()
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
@@ -140,7 +180,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally incldue an old
|
||||
# version of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
@@ -153,14 +193,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
|
||||
|
||||
if(WIN32)
|
||||
find_package(dlfcn-win32 REQUIRED)
|
||||
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
|
||||
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
|
||||
endif()
|
||||
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
else()
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
@@ -190,14 +223,6 @@ target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(
|
||||
@@ -207,8 +232,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||
endif()
|
||||
|
@@ -5,35 +5,35 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_value_and_grad() {
|
||||
auto x = ones({200, 1000});
|
||||
eval(x);
|
||||
auto fn = [](array x) {
|
||||
auto x = mx::ones({200, 1000});
|
||||
mx::eval(x);
|
||||
auto fn = [](mx::array x) {
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
x = log(exp(x));
|
||||
x = mx::log(mx::exp(x));
|
||||
}
|
||||
return sum(x);
|
||||
return mx::sum(x);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(fn);
|
||||
auto grad_fn = mx::grad(fn);
|
||||
auto independent_value_and_grad = [&]() {
|
||||
auto value = fn(x);
|
||||
auto dfdx = grad_fn(x);
|
||||
return std::vector<array>{value, dfdx};
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
};
|
||||
TIME(independent_value_and_grad);
|
||||
|
||||
auto value_and_grad_fn = value_and_grad(fn);
|
||||
auto value_and_grad_fn = mx::value_and_grad(fn);
|
||||
auto combined_value_and_grad = [&]() {
|
||||
auto [value, dfdx] = value_and_grad_fn(x);
|
||||
return std::vector<array>{value, dfdx};
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
};
|
||||
TIME(combined_value_and_grad);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_value_and_grad();
|
||||
}
|
||||
|
@@ -4,21 +4,21 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_add_op() {
|
||||
std::vector<int> sizes(1, 1);
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
sizes.push_back(10 * sizes.back());
|
||||
}
|
||||
set_default_device(Device::cpu);
|
||||
set_default_device(mx::Device::cpu);
|
||||
for (auto size : sizes) {
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
std::cout << "Size " << size << std::endl;
|
||||
TIMEM("cpu", add, a, b, Device::cpu);
|
||||
TIMEM("gpu", add, a, b, Device::gpu);
|
||||
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
|
||||
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -6,105 +6,105 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_irregular_binary_ops_1D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
b = slice(b, {0}, {size}, {step});
|
||||
TIMEM("1D strided", add, a, b, device);
|
||||
TIMEM("1D strided", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_2D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
auto a = random::uniform({size, size});
|
||||
auto b = random::uniform({size, size});
|
||||
eval(a, b);
|
||||
TIMEM("2D regular", add, a, b, device);
|
||||
auto a = mx::random::uniform({size, size});
|
||||
auto b = mx::random::uniform({size, size});
|
||||
mx::eval(a, b);
|
||||
TIMEM("2D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b);
|
||||
eval(b);
|
||||
TIMEM("2D transpose", add, a, b, device);
|
||||
b = mx::transpose(b);
|
||||
mx::eval(b);
|
||||
TIMEM("2D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({size});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
||||
b = mx::random::uniform({size});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
|
||||
|
||||
b = reshape(b, {size, 1});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
||||
b = mx::reshape(b, {size, 1});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_3D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int d0 = 32;
|
||||
int d1 = 512;
|
||||
int d2 = 512;
|
||||
auto a = random::uniform({d0, d1, d2});
|
||||
auto b = random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", add, a, b, device);
|
||||
auto a = mx::random::uniform({d0, d1, d2});
|
||||
auto b = mx::random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 2, 1});
|
||||
TIMEM("3D transpose", add, a, b, device);
|
||||
b = mx::transpose(b, {0, 2, 1});
|
||||
TIMEM("3D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
||||
b = mx::random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
||||
b = mx::random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
||||
b = mx::random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
||||
b = mx::random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
||||
b = mx::random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
||||
b = mx::random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
auto a = random::uniform(shape);
|
||||
auto b = random::uniform(shape);
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
|
||||
TIMEM("4D regular", add, a, b, device);
|
||||
TIMEM("4D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D transpose", add, a, b, device);
|
||||
b = mx::transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
std::string om = "4D broadcast dims ";
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = 1;
|
||||
b = random::uniform(shape);
|
||||
b = mx::random::uniform(shape);
|
||||
std::ostringstream msg;
|
||||
msg << om << i;
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
|
||||
for (int j = i + 1; j < shape.size(); ++j) {
|
||||
shape[j] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
shape[j] = a.shape(j);
|
||||
|
||||
for (int k = j + 1; k < shape.size(); ++k) {
|
||||
shape[k] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j << ", " << k;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
shape[k] = a.shape(k);
|
||||
}
|
||||
}
|
||||
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
|
||||
}
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape;
|
||||
auto reshape_fn = [&shape, device](const array& a) {
|
||||
return reshape(a, shape, device);
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
};
|
||||
|
||||
int size = 64;
|
||||
int d = 2 * size;
|
||||
|
||||
auto a = random::uniform({d, d, d});
|
||||
auto a = mx::random::uniform({d, d, d});
|
||||
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D contiguous", reshape_fn, a);
|
||||
|
||||
a = transpose(a);
|
||||
a = mx::transpose(a);
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose", reshape_fn, a);
|
||||
TIMEM("3D mx::transpose", reshape_fn, a);
|
||||
|
||||
a = transpose(a, {1, 2, 0});
|
||||
a = mx::transpose(a, {1, 2, 0});
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
||||
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||
}
|
||||
|
||||
void time_irregular_astype_1D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto a = mx::random::uniform({size});
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
TIMEM("1D strided", astype, a, int32, device);
|
||||
TIMEM("1D strided", mx::astype, a, mx::int32, device);
|
||||
}
|
||||
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
|
||||
auto a = random::uniform(shape);
|
||||
TIMEM("2D regular", astype, a, int32, device);
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = transpose(a);
|
||||
TIMEM("2D transpose", astype, a, int32, device);
|
||||
a = mx::transpose(a);
|
||||
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
||||
a = mx::broadcast_to(mx::random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
||||
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 1) {
|
||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
||||
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
|
||||
}
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_irregular_binary_ops_1D();
|
||||
time_irregular_binary_ops_2D();
|
||||
time_irregular_binary_ops_3D();
|
||||
|
@@ -3,20 +3,20 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_creation_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||
TIME(full_fp32);
|
||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||
TIME(zeros_fp32);
|
||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||
TIME(ones_fp32);
|
||||
|
||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||
TIME(arange_fp32);
|
||||
}
|
||||
|
||||
@@ -24,194 +24,196 @@ void time_type_conversions() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = zeros(shape, float32);
|
||||
eval(a);
|
||||
TIMEM("float32 to int32", astype, a, int32, device);
|
||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||
auto a = mx::zeros(shape, mx::float32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
|
||||
a = zeros(shape, int32);
|
||||
eval(a);
|
||||
TIMEM("int32 to float32", astype, a, float32, device);
|
||||
a = mx::zeros(shape, mx::int32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||
|
||||
a = zeros(shape, bool_);
|
||||
eval(a);
|
||||
TIMEM("bool to float32", astype, a, float32, device);
|
||||
TIMEM("bool to int32", astype, a, int32, device);
|
||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||
a = mx::zeros(shape, mx::bool_);
|
||||
mx::eval(a);
|
||||
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
}
|
||||
|
||||
void time_random_generation() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
|
||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||
TIME(uniform);
|
||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||
TIME(normal);
|
||||
}
|
||||
|
||||
void time_unary_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = random::normal({M, N});
|
||||
eval(a);
|
||||
auto a = mx::random::normal({M, N});
|
||||
mx::eval(a);
|
||||
TIME(mlx::core::abs, a, device);
|
||||
TIME(negative, a, device);
|
||||
TIME(sign, a, device);
|
||||
TIME(square, a, device);
|
||||
TIME(mx::negative, a, device);
|
||||
TIME(mx::sign, a, device);
|
||||
TIME(mx::square, a, device);
|
||||
TIME(mlx::core::sqrt, a, device);
|
||||
TIME(rsqrt, a, device);
|
||||
TIME(mx::rsqrt, a, device);
|
||||
TIME(mlx::core::exp, a, device);
|
||||
|
||||
a = random::uniform({M, N});
|
||||
a = mx::random::uniform({M, N});
|
||||
TIME(mlx::core::log, a, device);
|
||||
}
|
||||
|
||||
void time_binary_ops() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto condition = random::randint(0, 2, {M, N, K});
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
|
||||
TIME(add, a, b, device);
|
||||
TIME(subtract, a, b, device);
|
||||
TIME(multiply, a, b, device);
|
||||
TIME(divide, a, b, device);
|
||||
TIME(maximum, a, b, device);
|
||||
TIME(minimum, a, b, device);
|
||||
TIME(where, condition, a, b, device);
|
||||
TIME(mx::add, a, b, device);
|
||||
TIME(mx::subtract, a, b, device);
|
||||
TIME(mx::multiply, a, b, device);
|
||||
TIME(mx::divide, a, b, device);
|
||||
TIME(mx::maximum, a, b, device);
|
||||
TIME(mx::minimum, a, b, device);
|
||||
TIME(mx::where, condition, a, b, device);
|
||||
|
||||
condition = array({true});
|
||||
b = random::uniform({1});
|
||||
eval(b);
|
||||
TIMEM("scalar", add, a, b, device);
|
||||
TIMEM("vector-scalar", subtract, a, b, device);
|
||||
TIMEM("scalar-vector", subtract, b, a, device);
|
||||
TIMEM("scalar", multiply, a, b, device);
|
||||
TIMEM("vector-scalar", divide, a, b, device);
|
||||
TIMEM("scalar-vector", divide, b, a, device);
|
||||
TIMEM("scalar-vector", where, condition, a, b, device);
|
||||
condition = mx::array({true});
|
||||
b = mx::random::uniform({1});
|
||||
mx::eval(b);
|
||||
TIMEM("scalar", mx::add, a, b, device);
|
||||
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||
TIMEM("scalar", mx::multiply, a, b, device);
|
||||
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||
|
||||
condition = broadcast_to(array({true}), {1000, 100});
|
||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
|
||||
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
mx::eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||
}
|
||||
|
||||
void time_strided_ops() {
|
||||
int M = 50, N = 50, O = 50, P = 50;
|
||||
auto a = random::uniform({M, N, O, P});
|
||||
auto b = random::uniform({M, N, O, P});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIMEM("non-strided", add, a, b, device);
|
||||
a = transpose(a, {1, 0, 2, 3});
|
||||
b = transpose(b, {3, 2, 0, 1});
|
||||
eval(a, b);
|
||||
TIMEM("strided", add, a, b, device);
|
||||
auto a = mx::random::uniform({M, N, O, P});
|
||||
auto b = mx::random::uniform({M, N, O, P});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIMEM("non-strided", mx::add, a, b, device);
|
||||
a = mx::transpose(a, {1, 0, 2, 3});
|
||||
b = mx::transpose(b, {3, 2, 0, 1});
|
||||
mx::eval(a, b);
|
||||
TIMEM("strided", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_comparisons() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(equal, a, b, device);
|
||||
TIME(greater, a, b, device);
|
||||
TIME(greater_equal, a, b, device);
|
||||
TIME(less, a, b, device);
|
||||
TIME(less_equal, a, b, device);
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::equal, a, b, device);
|
||||
TIME(mx::greater, a, b, device);
|
||||
TIME(mx::greater_equal, a, b, device);
|
||||
TIME(mx::less, a, b, device);
|
||||
TIME(mx::less_equal, a, b, device);
|
||||
}
|
||||
|
||||
void time_matvec() {
|
||||
int M = 2000, N = 200;
|
||||
auto a = random::uniform({M, N});
|
||||
auto b = random::uniform({N});
|
||||
auto c = random::uniform({M});
|
||||
eval(a, b, c);
|
||||
auto matvec = [&]() { return matmul(a, b); };
|
||||
auto a = mx::random::uniform({M, N});
|
||||
auto b = mx::random::uniform({N});
|
||||
auto c = mx::random::uniform({M});
|
||||
mx::eval(a, b, c);
|
||||
auto matvec = [&]() { return mx::matmul(a, b); };
|
||||
TIME(matvec);
|
||||
|
||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
||||
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
||||
TIME(matvec_transpose);
|
||||
}
|
||||
|
||||
void time_matmul() {
|
||||
int M = 1000, N = 1000, K = 1000;
|
||||
auto a = random::uniform({M, K});
|
||||
auto b = random::uniform({K, N});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(matmul, a, b, device);
|
||||
auto a = mx::random::uniform({M, K});
|
||||
auto b = mx::random::uniform({K, N});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::matmul, a, b, device);
|
||||
|
||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
||||
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
||||
TIME(transpose_matmul);
|
||||
}
|
||||
|
||||
void time_reductions() {
|
||||
auto a = random::normal({10000, 1000});
|
||||
eval(a);
|
||||
auto sum_all = [&a]() { return sum(a, false); };
|
||||
auto a = mx::random::normal({10000, 1000});
|
||||
mx::eval(a);
|
||||
auto sum_all = [&a]() { return mx::sum(a, false); };
|
||||
TIME(sum_all);
|
||||
|
||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
||||
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
||||
TIME(sum_along_0);
|
||||
|
||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
||||
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
||||
TIME(sum_along_1);
|
||||
|
||||
auto prod_all = [&a]() { return prod(a, false); };
|
||||
auto prod_all = [&a]() { return mx::prod(a, false); };
|
||||
TIME(prod_all);
|
||||
|
||||
auto all_true = [&a]() { return all(a, false); };
|
||||
auto all_true = [&a]() { return mx::all(a, false); };
|
||||
TIME(all_true);
|
||||
|
||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
||||
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
||||
TIME(all_along_0);
|
||||
|
||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
||||
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
||||
TIME(all_along_1);
|
||||
|
||||
auto any_true = [&a]() { return any(a, false); };
|
||||
auto any_true = [&a]() { return mx::any(a, false); };
|
||||
TIME(any_true);
|
||||
|
||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
||||
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
||||
TIME(argmin_along_0);
|
||||
|
||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
auto a = random::normal({1000, 768});
|
||||
eval(a);
|
||||
auto indices = random::randint(0, 1000, {256});
|
||||
eval(indices);
|
||||
auto a = mx::random::normal({1000, 768});
|
||||
mx::eval(a);
|
||||
auto indices = mx::random::randint(0, 1000, {256});
|
||||
mx::eval(indices);
|
||||
|
||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
||||
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
||||
TIME(embedding_lookup);
|
||||
|
||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
||||
eval(indices);
|
||||
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
||||
mx::eval(indices);
|
||||
|
||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
||||
auto single_element_lookup = [&a, &indices]() {
|
||||
return mx::take(a, indices);
|
||||
};
|
||||
TIME(single_element_lookup);
|
||||
|
||||
indices = random::randint(0, 1000, {256});
|
||||
auto updates = random::normal({256, 1, 768});
|
||||
eval(indices, updates);
|
||||
indices = mx::random::randint(0, 1000, {256});
|
||||
auto updates = mx::random::normal({256, 1, 768});
|
||||
mx::eval(indices, updates);
|
||||
|
||||
auto embedding_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@@ -223,10 +225,10 @@ void time_gather_scatter() {
|
||||
};
|
||||
TIME(embedding_add);
|
||||
|
||||
a = reshape(a, {-1});
|
||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = random::normal({256 * 768, 1});
|
||||
eval(a, indices, updates);
|
||||
a = mx::reshape(a, {-1});
|
||||
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = mx::random::normal({256 * 768, 1});
|
||||
mx::eval(a, indices, updates);
|
||||
|
||||
auto single_element_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@@ -240,21 +242,21 @@ void time_gather_scatter() {
|
||||
}
|
||||
|
||||
void time_divmod() {
|
||||
auto a = random::normal({1000});
|
||||
auto b = random::normal({1000});
|
||||
eval({a, b});
|
||||
auto a = mx::random::normal({1000});
|
||||
auto b = mx::random::normal({1000});
|
||||
mx::eval({a, b});
|
||||
|
||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
||||
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
||||
TIME(divmod_fused);
|
||||
|
||||
auto divmod_separate = [&a, &b]() {
|
||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
||||
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
||||
};
|
||||
TIME(divmod_separate);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_creation_ops();
|
||||
time_type_conversions();
|
||||
time_unary_ops();
|
||||
|
@@ -12,7 +12,7 @@ dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
def attention(q, k, v, mask=None):
|
||||
def _sdpa(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
@@ -20,6 +20,9 @@ def attention(q, k, v):
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
if mask is not None:
|
||||
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
|
||||
s = mx.where(m, s, mx.finfo(s.dtype).min)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
@@ -29,9 +32,9 @@ def attention(q, k, v):
|
||||
return q
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
def sdpa(q, k, v, mask=None):
|
||||
for i in range(loops):
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||
return q
|
||||
|
||||
|
||||
@@ -53,6 +56,26 @@ def time_self_attention_sdpa():
|
||||
time_fn(sdpa, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_sdpa_with_mask():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
mask = mx.full((L,), True)
|
||||
mask[L // 2 :] = False
|
||||
mx.eval(q, k, v, mask)
|
||||
|
||||
def sdpa_mask(*args):
|
||||
return sdpa(*args, mask=mask)
|
||||
|
||||
def attention_mask(*args):
|
||||
return attention(*args, mask=mask)
|
||||
|
||||
time_fn(attention_mask, q, k, v)
|
||||
time_fn(sdpa_mask, q, k, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
time_self_attention_sdpa_with_mask()
|
||||
|
121
docs/src/dev/mlx_in_cpp.rst
Normal file
121
docs/src/dev/mlx_in_cpp.rst
Normal file
@@ -0,0 +1,121 @@
|
||||
.. _mlx_in_cpp:
|
||||
|
||||
Using MLX in C++
|
||||
================
|
||||
|
||||
You can use MLX in a C++ project with CMake.
|
||||
|
||||
.. note::
|
||||
|
||||
This guide is based one the following `example using MLX in C++
|
||||
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
|
||||
|
||||
First install MLX:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U mlx
|
||||
|
||||
You can also install the MLX Python package from source or just the C++
|
||||
library. For more information see the :ref:`documentation on installing MLX
|
||||
<build_and_install>`.
|
||||
|
||||
Next make an example program in ``example.cpp``:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
auto x = mx::array({1, 2, 3});
|
||||
auto y = mx::array({1, 2, 3});
|
||||
std::cout << x + y << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
The next step is to setup a CMake file in ``CMakeLists.txt``:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(example LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
|
||||
Depending on how you installed MLX, you may need to tell CMake where to
|
||||
find it.
|
||||
|
||||
If you installed MLX with Python, then add the following to the CMake file:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
|
||||
If you installed the MLX C++ package to a system path, then CMake should be
|
||||
able to find it. If you installed it to a non-standard location or CMake can't
|
||||
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
set(MLX_ROOT "/path/to/mlx/")
|
||||
|
||||
Next, instruct CMake to find MLX:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
Finally, add the ``example.cpp`` program as an executable and link MLX.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
add_executable(example example.cpp)
|
||||
target_link_libraries(example PRIVATE mlx)
|
||||
|
||||
You can build the example with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
|
||||
And run it with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./build/example
|
||||
|
||||
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
|
||||
|
||||
.. list-table:: Package Variables
|
||||
:widths: 20 20
|
||||
:header-rows: 1
|
||||
|
||||
* - Variable
|
||||
- Description
|
||||
* - MLX_FOUND
|
||||
- ``True`` if MLX is found
|
||||
* - MLX_INCLUDE_DIRS
|
||||
- Include directory
|
||||
* - MLX_LIBRARIES
|
||||
- Libraries to link against
|
||||
* - MLX_CXX_FLAGS
|
||||
- Additional compiler flags
|
||||
* - MLX_BUILD_ACCELERATE
|
||||
- ``True`` if MLX was built with Accelerate
|
||||
* - MLX_BUILD_METAL
|
||||
- ``True`` if MLX was built with Metal
|
@@ -45,6 +45,7 @@ are the CPU and GPU.
|
||||
usage/numpy
|
||||
usage/distributed
|
||||
usage/using_streams
|
||||
usage/export
|
||||
|
||||
.. toctree::
|
||||
:caption: Examples
|
||||
@@ -61,6 +62,7 @@ are the CPU and GPU.
|
||||
python/array
|
||||
python/data_types
|
||||
python/devices_and_streams
|
||||
python/export
|
||||
python/ops
|
||||
python/random
|
||||
python/transforms
|
||||
@@ -86,3 +88,4 @@ are the CPU and GPU.
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
dev/custom_metal_kernels
|
||||
dev/mlx_in_cpp
|
||||
|
@@ -1,3 +1,5 @@
|
||||
.. _build_and_install:
|
||||
|
||||
Build and Install
|
||||
=================
|
||||
|
||||
@@ -53,7 +55,7 @@ Build Requirements
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
|
||||
- Xcode >= 15.0 and macOS SDK >= 14.0
|
||||
|
||||
.. note::
|
||||
|
@@ -66,3 +66,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one
|
||||
Dtype
|
||||
DtypeCategory
|
||||
issubdtype
|
||||
finfo
|
||||
|
14
docs/src/python/export.rst
Normal file
14
docs/src/python/export.rst
Normal file
@@ -0,0 +1,14 @@
|
||||
.. _export:
|
||||
|
||||
Export Functions
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
export_function
|
||||
import_function
|
||||
exporter
|
||||
export_to_dot
|
@@ -89,6 +89,7 @@ Operations
|
||||
isneginf
|
||||
isposinf
|
||||
issubdtype
|
||||
kron
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
@@ -144,6 +145,8 @@ Operations
|
||||
sign
|
||||
sin
|
||||
sinh
|
||||
slice
|
||||
slice_update
|
||||
softmax
|
||||
sort
|
||||
split
|
||||
@@ -168,6 +171,7 @@ Operations
|
||||
tri
|
||||
tril
|
||||
triu
|
||||
unflatten
|
||||
var
|
||||
view
|
||||
where
|
||||
|
@@ -421,3 +421,77 @@ the most opportunity to optimize the computation graph:
|
||||
# Compiling the outer function is good to do as it will likely
|
||||
# be faster even though the inner functions are compiled
|
||||
fun = mx.compile(outer)
|
||||
|
||||
|
||||
|
||||
.. _shapeless_compile:
|
||||
|
||||
Shapeless Compilation
|
||||
---------------------
|
||||
|
||||
When the shape of an input to a compiled function changes, the function is
|
||||
recompiled. You can compile a function once and run it on inputs with
|
||||
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
|
||||
case changes to the shapes of the inputs do not cause the function to be
|
||||
recompiled.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return mx.abs(x + y)
|
||||
|
||||
compiled_fun = mx.compile(fun, shapeless=True)
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(-2.0)
|
||||
|
||||
# Firt call compiles the function
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
# Second call with different shapes
|
||||
# does not recompile the function
|
||||
x = mx.array([1.0, -6.0])
|
||||
y = mx.array([-2.0, 3.0])
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
|
||||
Use shapeless compilations carefully. Since compilation is not triggered when
|
||||
shapes change, any graphs which are conditional on the input shapes will not
|
||||
work as expected. Shape-dependent computations are common and sometimes subtle
|
||||
to detect. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
return x.reshape(x.shape[0] * x.shape[1], -1)
|
||||
|
||||
compiled_fun = mx.compile(fun, shapeless=True)
|
||||
|
||||
x = mx.random.uniform(shape=(2, 3, 4))
|
||||
|
||||
out = compiled_fun(x)
|
||||
|
||||
x = mx.random.uniform(shape=(5, 5, 3))
|
||||
|
||||
# Error, can't reshape (5, 5, 3) to (6, -1)
|
||||
out = compiled_fun(x)
|
||||
|
||||
The second call to the ``compiled_fun`` fails because of the call to
|
||||
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
|
||||
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
return x.flatten(0, 1)
|
||||
|
||||
compiled_fun = mx.compile(fun, shapeless=True)
|
||||
|
||||
x = mx.random.uniform(shape=(2, 3, 4))
|
||||
|
||||
out = compiled_fun(x)
|
||||
|
||||
x = mx.random.uniform(shape=(5, 5, 3))
|
||||
|
||||
# Ok
|
||||
out = compiled_fun(x)
|
||||
|
@@ -141,12 +141,13 @@ everything else remaining the same.
|
||||
from mlx.utils import tree_map
|
||||
|
||||
def all_reduce_grads(grads):
|
||||
N = mx.distributed.init()
|
||||
N = mx.distributed.init().size()
|
||||
if N == 1:
|
||||
return grads
|
||||
return tree_map(
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads)
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads
|
||||
)
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
|
288
docs/src/usage/export.rst
Normal file
288
docs/src/usage/export.rst
Normal file
@@ -0,0 +1,288 @@
|
||||
.. _export_usage:
|
||||
|
||||
Exporting Functions
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
MLX has an API to export and import functions to and from a file. This lets you
|
||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||
front-end (e.g. C++).
|
||||
|
||||
This guide walks through the basics of the MLX export API with some examples.
|
||||
To see the full list of functions check-out the :ref:`API documentation
|
||||
<export>`.
|
||||
|
||||
Basics of Exporting
|
||||
-------------------
|
||||
|
||||
Let's start with a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
To export a function, provide sample input arrays that the function
|
||||
can be called with. The data doesn't matter, but the shapes and types of the
|
||||
arrays do. In the above example we exported ``fun`` with two ``float32``
|
||||
scalar arrays. We can then import the function and run it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
add_fun = mx.import_function("add.mlxfn")
|
||||
|
||||
out, = add_fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints: array(3, dtype=float32)
|
||||
print(out)
|
||||
|
||||
out, = add_fun(mx.array(1.0), mx.array(3.0))
|
||||
# Prints: array(4, dtype=float32)
|
||||
print(out)
|
||||
|
||||
# Raises an exception
|
||||
add_fun(mx.array(1), mx.array(3.0))
|
||||
|
||||
# Raises an exception
|
||||
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
|
||||
|
||||
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
|
||||
shapes and types of the inputs are different than the shapes and types of the
|
||||
example inputs we exported the function with.
|
||||
|
||||
Also notice that even though the original ``fun`` returns a single output
|
||||
array, the imported function always returns a tuple of one or more arrays.
|
||||
|
||||
The inputs to :func:`export_function` and to an imported function can be
|
||||
specified as variable positional arguments or as a tuple of arrays:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
|
||||
# Both arguments to fun are positional
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
# Same as above
|
||||
mx.export_function("add.mlxfn", fun, (x, y))
|
||||
|
||||
imported_fun = mx.import_function("add.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_fun(x, y)
|
||||
|
||||
# Also ok
|
||||
out, = imported_fun((x, y))
|
||||
|
||||
You can pass example inputs to functions as positional or keyword arguments. If
|
||||
you use keyword arguments to export the function, then you have to use the same
|
||||
keyword arguments when calling the imported function.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
# One argument to fun is positional, the other is a kwarg
|
||||
mx.export_function("add.mlxfn", fun, x, y=y)
|
||||
|
||||
imported_fun = mx.import_function("add.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_fun(x, y=y)
|
||||
|
||||
# Also ok
|
||||
out, = imported_fun((x,), {"y": y})
|
||||
|
||||
# Raises since the keyword argument is missing
|
||||
out, = imported_fun(x, y)
|
||||
|
||||
# Raises since the keyword argument has the wrong key
|
||||
out, = imported_fun(x, z=y)
|
||||
|
||||
|
||||
Exporting Modules
|
||||
-----------------
|
||||
|
||||
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
|
||||
in the exported function. Here's an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = nn.Linear(4, 4)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
def call(x):
|
||||
return model(x)
|
||||
|
||||
mx.export_function("model.mlxfn", call, mx.zeros(4))
|
||||
|
||||
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
|
||||
parameters are also saved to the ``model.mlxfn`` file.
|
||||
|
||||
.. note::
|
||||
|
||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||
they are evaluated. The computation graph that gets exported will include
|
||||
the computation that produces enclosed inputs.
|
||||
|
||||
If the above example was missing ``mx.eval(model.parameters()``, the
|
||||
exported function would include the random initialization of the
|
||||
:obj:`mlx.nn.Module` parameters.
|
||||
|
||||
If you only want to export the ``Module.__call__`` function without the
|
||||
parameters, pass them as inputs to the ``call`` wrapper:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = nn.Linear(4, 4)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
def call(x, **params):
|
||||
# Set the model's parameters to the input parameters
|
||||
model.update(tree_unflatten(list(params.items())))
|
||||
return model(x)
|
||||
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||
|
||||
|
||||
Shapeless Exports
|
||||
-----------------
|
||||
|
||||
Just like :func:`compile`, functions can also be exported for dynamically shaped
|
||||
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
|
||||
to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
|
||||
imported_abs = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
With ``shapeless=False`` (which is the default), the second call to
|
||||
``imported_abs`` would raise an exception with a shape mismatch.
|
||||
|
||||
Shapeless exporting works the same as shapeless compilation and should be
|
||||
used carefully. See the :ref:`documentation on shapeless compilation
|
||||
<shapeless_compile>` for more information.
|
||||
|
||||
Exporting Multiple Traces
|
||||
-------------------------
|
||||
|
||||
In some cases, functions build different computation graphs for different
|
||||
input arguments. A simple way to manage this is to export to a new file with
|
||||
each set of inputs. This is a fine option in many cases. But it can be
|
||||
suboptimal if the exported functions have a large amount of duplicate constant
|
||||
data (for example the parameters of a :obj:`mlx.nn.Module`).
|
||||
|
||||
The export API in MLX lets you export multiple traces of the same function to
|
||||
a single file by creating an exporting context manager with :func:`exporter`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y=None):
|
||||
constant = mx.array(3.0)
|
||||
if y is not None:
|
||||
x += y
|
||||
return x + constant
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
exporter(mx.array(1.0))
|
||||
exporter(mx.array(1.0), y=mx.array(0.0))
|
||||
|
||||
imported_function = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Call the function with y=None
|
||||
out, = imported_function(mx.array(1.0))
|
||||
print(out)
|
||||
|
||||
# Call the function with y specified
|
||||
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
|
||||
print(out)
|
||||
|
||||
In the above example the function constant data, (i.e. ``constant``), is only
|
||||
saved once.
|
||||
|
||||
Transformations with Imported Functions
|
||||
---------------------------------------
|
||||
|
||||
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
|
||||
on imported functions just like regular Python functions:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
return mx.sin(x)
|
||||
|
||||
x = mx.array(0.0)
|
||||
mx.export_function("sine.mlxfn", fun, x)
|
||||
|
||||
imported_fun = mx.import_function("sine.mlxfn")
|
||||
|
||||
# Take the derivative of the imported function
|
||||
dfdx = mx.grad(lambda x: imported_fun(x)[0])
|
||||
# Prints: array(1, dtype=float32)
|
||||
print(dfdx(x))
|
||||
|
||||
# Compile the imported function
|
||||
mx.compile(imported_fun)
|
||||
# Prints: array(0, dtype=float32)
|
||||
print(compiled_fun(x)[0])
|
||||
|
||||
|
||||
Importing Functions in C++
|
||||
--------------------------
|
||||
|
||||
Importing and running functions in C++ is basically the same as importing and
|
||||
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
|
||||
setup a simple C++ project that uses MLX as a library.
|
||||
|
||||
Next, export a simple function from Python:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return mx.exp(x + y)
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
mx.export_function("fun.mlxfn", fun, x, y)
|
||||
|
||||
|
||||
Import and run the function in C++ with only a few lines of code:
|
||||
|
||||
.. code-block:: c++
|
||||
|
||||
auto fun = mx::import_function("fun.mlxfn");
|
||||
|
||||
auto inputs = {mx::array(1.0), mx::array(1.0)};
|
||||
auto outputs = fun(inputs);
|
||||
|
||||
// Prints: array(2, dtype=float32)
|
||||
std::cout << outputs[0] << std::endl;
|
||||
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||
|
||||
More Examples
|
||||
-------------
|
||||
|
||||
Here are a few more complete examples exporting more complex functions from
|
||||
Python and importing and running them in C++:
|
||||
|
||||
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_
|
22
examples/cmake_project/CMakeLists.txt
Normal file
22
examples/cmake_project/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(example LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# Comment the following two commands only the MLX C++ library is installed and
|
||||
# set(MLX_ROOT "/path/to/mlx") directly if needed.
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
add_executable(example example.cpp)
|
||||
target_link_libraries(example PRIVATE mlx)
|
26
examples/cmake_project/README.md
Normal file
26
examples/cmake_project/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
## Build and Run
|
||||
|
||||
Install MLX with Python:
|
||||
|
||||
```bash
|
||||
pip install mlx>=0.22
|
||||
```
|
||||
|
||||
Build the C++ example:
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
Run the C++ example:
|
||||
|
||||
```
|
||||
./build/example
|
||||
```
|
||||
|
||||
which should output:
|
||||
|
||||
```
|
||||
array([2, 4, 6], dtype=int32)
|
||||
```
|
14
examples/cmake_project/example.cpp
Normal file
14
examples/cmake_project/example.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
auto x = mx::array({1, 2, 3});
|
||||
auto y = mx::array({1, 2, 3});
|
||||
std::cout << x + y << std::endl;
|
||||
return 0;
|
||||
}
|
@@ -4,19 +4,19 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
if (!distributed::is_available()) {
|
||||
if (!mx::distributed::is_available()) {
|
||||
std::cout << "No communication backend found" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto global_group = distributed::init();
|
||||
auto global_group = mx::distributed::init();
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
array x = ones({10});
|
||||
array out = distributed::all_sum(x, global_group);
|
||||
mx::array x = mx::ones({10});
|
||||
mx::array out = mx::distributed::all_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of linear regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.01;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
|
||||
// The input examples (design matrix)
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
|
||||
// Noisy labels
|
||||
auto eps = 1e-2 * random::normal({num_examples});
|
||||
auto y = matmul(X, w_star) + eps;
|
||||
auto eps = 1e-2 * mx::random::normal({num_examples});
|
||||
auto y = mx::matmul(X, w_star) + eps;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto yhat = matmul(X, w);
|
||||
return (0.5f / num_examples) * sum(square(yhat - y));
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto yhat = mx::matmul(X, w);
|
||||
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
|
||||
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
||||
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
||||
|
@@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of logistic regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.1;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
|
||||
// The input examples
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
|
||||
// Labels
|
||||
auto y = matmul(X, w_star) > 0;
|
||||
auto y = mx::matmul(X, w_star) > 0;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto logits = matmul(X, w);
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto logits = mx::matmul(X, w);
|
||||
auto scale = (1.0f / num_examples);
|
||||
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
|
||||
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
|
||||
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
||||
<< throughput << " (it/s)." << std::endl;
|
||||
|
@@ -5,27 +5,27 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
metal::start_capture("mlx_trace.gputrace");
|
||||
mx::metal::start_capture("mlx_trace.gputrace");
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
auto s2 = new_stream(mx::Device::gpu);
|
||||
auto s3 = new_stream(mx::Device::gpu);
|
||||
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
|
||||
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
|
||||
auto x = mx::add(a, a, s2);
|
||||
auto y = mx::add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
std::cout << mx::multiply(x, y) << std::endl;
|
||||
|
||||
metal::stop_capture();
|
||||
mx::metal::stop_capture();
|
||||
}
|
||||
|
@@ -5,11 +5,11 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void array_basics() {
|
||||
// Make a scalar array:
|
||||
array x(1.0);
|
||||
mx::array x(1.0);
|
||||
|
||||
// Get the value out of it:
|
||||
auto s = x.item<float>();
|
||||
@@ -29,31 +29,31 @@ void array_basics() {
|
||||
|
||||
// The datatype should be float32:
|
||||
auto dtype = x.dtype();
|
||||
assert(dtype == float32);
|
||||
assert(dtype == mx::float32);
|
||||
|
||||
// Specify the dtype when constructing the array:
|
||||
x = array(1, int32);
|
||||
assert(x.dtype() == int32);
|
||||
x = mx::array(1, mx::int32);
|
||||
assert(x.dtype() == mx::int32);
|
||||
x.item<int>(); // OK
|
||||
// x.item<float>(); // Undefined!
|
||||
|
||||
// Make a multidimensional array:
|
||||
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
|
||||
// Make an array of shape {2, 2} filled with ones:
|
||||
auto y = ones({2, 2});
|
||||
auto y = mx::ones({2, 2});
|
||||
|
||||
// Pointwise add x and y:
|
||||
auto z = add(x, y);
|
||||
auto z = mx::add(x, y);
|
||||
|
||||
// Same thing:
|
||||
z = x + y;
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data:
|
||||
assert(z.dtype() == float32);
|
||||
assert(z.dtype() == mx::float32);
|
||||
assert(z.shape(0) == 2);
|
||||
assert(z.shape(1) == 2);
|
||||
|
||||
@@ -63,33 +63,33 @@ void array_basics() {
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
eval(z);
|
||||
mx::eval(z);
|
||||
|
||||
// Of course the array can still be an input to other operations. You can even
|
||||
// call eval on the array again, this will just be a no-op:
|
||||
eval(z); // no-op
|
||||
// Of course the array can still be an input to other operations. You can
|
||||
// even call eval on the array again, this will just be a no-op:
|
||||
mx::eval(z); // no-op
|
||||
|
||||
// Some functions or methods on arrays implicitly evaluate them. For example
|
||||
// accessing a value in an array or printing the array implicitly evaluate it:
|
||||
z = ones({1});
|
||||
z = mx::ones({1});
|
||||
z.item<float>(); // implicit evaluation
|
||||
|
||||
z = ones({2, 2});
|
||||
z = mx::ones({2, 2});
|
||||
std::cout << z << std::endl; // implicit evaluation
|
||||
}
|
||||
|
||||
void automatic_differentiation() {
|
||||
auto fn = [](array x) { return square(x); };
|
||||
auto fn = [](mx::array x) { return mx::square(x); };
|
||||
|
||||
// Computing the derivative function of a function
|
||||
auto grad_fn = grad(fn);
|
||||
auto grad_fn = mx::grad(fn);
|
||||
// Call grad_fn on the input to get the derivative
|
||||
auto x = array(1.5);
|
||||
auto x = mx::array(1.5);
|
||||
auto dfdx = grad_fn(x);
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto d2fdx2 = grad(grad(fn))(x);
|
||||
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
|
||||
// d2fdx2 is 2
|
||||
}
|
||||
|
||||
|
22
examples/export/CMakeLists.txt
Normal file
22
examples/export/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(import_mlx LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
add_executable(eval_mlp eval_mlp.cpp)
|
||||
target_link_libraries(eval_mlp PRIVATE mlx)
|
||||
|
||||
add_executable(train_mlp train_mlp.cpp)
|
||||
target_link_libraries(train_mlp PRIVATE mlx)
|
49
examples/export/README.md
Normal file
49
examples/export/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
## Setup
|
||||
|
||||
Install MLX:
|
||||
|
||||
```bash
|
||||
pip install mlx>=0.22
|
||||
```
|
||||
|
||||
Build the C++ examples:
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
### Eval MLP
|
||||
|
||||
Run the Python script to export the eval function:
|
||||
|
||||
```bash
|
||||
python eval_mlp.py
|
||||
```
|
||||
|
||||
Then run the C++ program to import and run the function:
|
||||
|
||||
```
|
||||
./build/eval_mlp
|
||||
```
|
||||
|
||||
The Python and C++ programs should output the same result.
|
||||
|
||||
### Train MLP
|
||||
|
||||
Run the Python script to export the model initialization and training
|
||||
functions:
|
||||
|
||||
```bash
|
||||
python train_mlp.py
|
||||
```
|
||||
|
||||
Then run the C++ program to import and run the functions:
|
||||
|
||||
```
|
||||
./build/train_mlp
|
||||
```
|
||||
|
||||
The Python and C++ programs should output the same results.
|
25
examples/export/eval_mlp.cpp
Normal file
25
examples/export/eval_mlp.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
|
||||
// Make the input
|
||||
mx::random::seed(42);
|
||||
auto example_x = mx::random::uniform({batch_size, input_dim});
|
||||
|
||||
// Import the function
|
||||
auto forward = mx::import_function("eval_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
auto out = forward({example_x})[0];
|
||||
|
||||
std::cout << out << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
52
examples/export/eval_mlp.py
Normal file
52
examples/export/eval_mlp.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""A simple MLP."""
|
||||
|
||||
def __init__(
|
||||
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||
self.layers = [
|
||||
nn.Linear(idim, odim)
|
||||
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = nn.relu(l(x))
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
batch_size = 8
|
||||
input_dim = 32
|
||||
output_dim = 10
|
||||
|
||||
# Load the model
|
||||
mx.random.seed(0) # Seed for params
|
||||
model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
|
||||
mx.eval(model)
|
||||
|
||||
# Note, the model parameters are saved in the export function
|
||||
def forward(x):
|
||||
return model(x)
|
||||
|
||||
mx.random.seed(42) # Seed for input
|
||||
example_x = mx.random.uniform(shape=(batch_size, input_dim))
|
||||
|
||||
mx.export_function("eval_mlp.mlxfn", forward, example_x)
|
||||
|
||||
# Import in Python
|
||||
imported_forward = mx.import_function("eval_mlp.mlxfn")
|
||||
expected = forward(example_x)
|
||||
(out,) = imported_forward(example_x)
|
||||
assert mx.allclose(expected, out)
|
||||
print(out)
|
35
examples/export/train_mlp.cpp
Normal file
35
examples/export/train_mlp.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
int output_dim = 10;
|
||||
|
||||
auto state = mx::import_function("init_mlp.mlxfn")({});
|
||||
|
||||
// Make the input
|
||||
mx::random::seed(42);
|
||||
auto example_X = mx::random::normal({batch_size, input_dim});
|
||||
auto example_y = mx::random::randint(0, output_dim, {batch_size});
|
||||
|
||||
// Import the function
|
||||
auto step = mx::import_function("train_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
for (int it = 0; it < 100; ++it) {
|
||||
state.insert(state.end(), {example_X, example_y});
|
||||
state = step(state);
|
||||
eval(state);
|
||||
auto loss = state.back();
|
||||
state.pop_back();
|
||||
if (it % 10 == 0) {
|
||||
std::cout << "Loss " << loss.item<float>() << std::endl;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
76
examples/export/train_mlp.py
Normal file
76
examples/export/train_mlp.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""A simple MLP."""
|
||||
|
||||
def __init__(
|
||||
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||
self.layers = [
|
||||
nn.Linear(idim, odim)
|
||||
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = nn.relu(l(x))
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
batch_size = 8
|
||||
input_dim = 32
|
||||
output_dim = 10
|
||||
|
||||
def init():
|
||||
# Seed for the parameter initialization
|
||||
mx.random.seed(0)
|
||||
model = MLP(
|
||||
num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
|
||||
)
|
||||
optimizer = optim.SGD(learning_rate=1e-1)
|
||||
optimizer.init(model.parameters())
|
||||
state = [model.parameters(), optimizer.state]
|
||||
tree_structure, state = zip(*mlx.utils.tree_flatten(state))
|
||||
return model, optimizer, tree_structure, state
|
||||
|
||||
# Export the model parameter initialization
|
||||
model, optimizer, tree_structure, state = init()
|
||||
mx.eval(state)
|
||||
mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
|
||||
|
||||
def loss_fn(params, X, y):
|
||||
model.update(params)
|
||||
return nn.losses.cross_entropy(model(X), y, reduction="mean")
|
||||
|
||||
def step(*inputs):
|
||||
*state, X, y = inputs
|
||||
params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
|
||||
optimizer.state = opt_state
|
||||
loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
|
||||
params = optimizer.apply_gradients(grads, params)
|
||||
_, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
|
||||
return *state, loss
|
||||
|
||||
# Make some random data
|
||||
mx.random.seed(42)
|
||||
example_X = mx.random.normal(shape=(batch_size, input_dim))
|
||||
example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
|
||||
mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)
|
||||
|
||||
# Export one step of SGD
|
||||
imported_step = mx.import_function("train_mlp.mlxfn")
|
||||
|
||||
for it in range(100):
|
||||
*state, loss = imported_step(*state, example_X, example_y)
|
||||
if it % 10 == 0:
|
||||
print(f"Loss {loss.item():.6}")
|
@@ -18,8 +18,7 @@ find_package(
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
@@ -19,7 +19,7 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
@@ -32,24 +32,24 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input mx::array x
|
||||
const mx::array& y, // Input mx::array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
) {
|
||||
// Promote dtypes between x and y as needed
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
: promote_types(promoted_dtype, mx::float32);
|
||||
|
||||
// Cast x and y up to the determined dtype (on the same stream s)
|
||||
auto x_casted = astype(x, out_dtype, s);
|
||||
auto y_casted = astype(y, out_dtype, s);
|
||||
auto x_casted = mx::astype(x, out_dtype, s);
|
||||
auto y_casted = mx::astype(y, out_dtype, s);
|
||||
|
||||
// Broadcast the shapes of x and y (on the same stream s)
|
||||
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||
@@ -57,12 +57,12 @@ array axpby(
|
||||
|
||||
// Construct the array as the output of the Axpby primitive
|
||||
// with the broadcasted and upcasted arrays as inputs
|
||||
return array(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
return mx::array(
|
||||
/* const mx::Shape& shape = */ out_shape,
|
||||
/* mx::Dtype dtype = */ out_dtype,
|
||||
/* std::shared_ptr<mx::Primitive> primitive = */
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -71,16 +71,16 @@ array axpby(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
@@ -94,8 +94,8 @@ void axpby_impl(
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
@@ -105,8 +105,8 @@ void axpby_impl(
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@@ -114,14 +114,14 @@ void Axpby::eval(
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
if (out.dtype() == mx::float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
@@ -136,9 +136,9 @@ void Axpby::eval(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
@@ -150,10 +150,10 @@ void axpby_impl_accelerate(
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
copy_inplace(y, out, mx::CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
@@ -175,15 +175,15 @@ void axpby_impl_accelerate(
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
if (out.dtype() == mx::float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
@@ -198,8 +198,8 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
@@ -213,8 +213,8 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
@@ -225,7 +225,7 @@ void Axpby::eval_gpu(
|
||||
// and each stream carries its device identifiers
|
||||
auto& s = stream();
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
auto& d = mx::metal::device(s.device);
|
||||
|
||||
// Prepare to specialize based on contiguity
|
||||
bool contiguous_kernel =
|
||||
@@ -235,12 +235,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
@@ -279,7 +279,7 @@ void Axpby::eval_gpu(
|
||||
if (!contiguous_kernel) {
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_bytes(y.strides(), 7);
|
||||
compute_encoder.set_vector_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
}
|
||||
|
||||
@@ -302,8 +302,8 @@ void Axpby::eval_gpu(
|
||||
|
||||
/** Fail evaluation on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& out) {
|
||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||
}
|
||||
|
||||
@@ -314,9 +314,9 @@ void Axpby::eval_gpu(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> Axpby::jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
@@ -328,8 +328,8 @@ std::vector<array> Axpby::jvp(
|
||||
// scaled by beta
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
auto scale_arr = mx::array(scale, tangents[0].dtype());
|
||||
return {mx::multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
@@ -339,24 +339,24 @@ std::vector<array> Axpby::jvp(
|
||||
}
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> Axpby::vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
const std::vector<mx::array>&) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
std::vector<mx::array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
auto scale_arr = mx::array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
}
|
||||
@@ -367,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
@@ -5,7 +5,9 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation
|
||||
@@ -18,22 +20,22 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input array x
|
||||
const mx::array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Axpby : public Primitive {
|
||||
class Axpby : public mx::Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
explicit Axpby(mx::Stream stream, float alpha, float beta)
|
||||
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
const std::vector<mx::array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
@@ -76,14 +80,16 @@ class Axpby : public Primitive {
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
bool is_equivalent(const mx::Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
void eval(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
@@ -8,14 +8,12 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
&my_ext::axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
"alpha"_a,
|
||||
|
@@ -1,8 +1,8 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"cmake>=3.25",
|
||||
"mlx>=0.18.0",
|
||||
"nanobind==2.2.0",
|
||||
"nanobind==2.4.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
cmake>=3.25
|
||||
mlx>=0.21.0
|
||||
nanobind==2.2.0
|
||||
|
@@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
@@ -18,6 +19,16 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
if(MSVC)
|
||||
# Disable some MSVC warnings to speed up compilation.
|
||||
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
# Export symbols by default to behave like macOS/linux.
|
||||
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
else()
|
||||
|
@@ -10,22 +10,8 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
auto cval = static_cast<complex64_t>(val);
|
||||
init(&cval);
|
||||
}
|
||||
@@ -61,14 +47,14 @@ std::vector<array> array::make_arrays(
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
float32)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
@@ -119,7 +105,8 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
return (array_desc_->is_tracer && detail::in_tracing()) ||
|
||||
detail::retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
@@ -277,7 +264,19 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
}
|
||||
ad.inputs.clear();
|
||||
for (auto& [_, a] : input_map) {
|
||||
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
|
||||
bool is_deletable =
|
||||
(a.array_desc_.use_count() <= a.siblings().size() + 1);
|
||||
// An array with siblings is deletable only if all of its siblings
|
||||
// are deletable
|
||||
for (auto& s : a.siblings()) {
|
||||
if (!is_deletable) {
|
||||
break;
|
||||
}
|
||||
int is_input = (input_map.find(s.id()) != input_map.end());
|
||||
is_deletable &=
|
||||
s.array_desc_.use_count() <= a.siblings().size() + is_input;
|
||||
}
|
||||
if (is_deletable) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
@@ -310,7 +309,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
}
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto start = Shape(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
auto shape = arr.shape();
|
||||
shape.erase(shape.begin());
|
||||
|
19
mlx/array.h
19
mlx/array.h
@@ -17,7 +17,8 @@ namespace mlx::core {
|
||||
class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using Shape = std::vector<int32_t>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
|
||||
class array {
|
||||
@@ -34,29 +35,29 @@ class array {
|
||||
explicit array(const std::complex<float>& val, Dtype dtype = complex64);
|
||||
|
||||
template <typename It>
|
||||
array(
|
||||
explicit array(
|
||||
It data,
|
||||
Shape shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
template <typename T>
|
||||
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
|
||||
explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Special case so empty lists default to float32. */
|
||||
array(std::initializer_list<float> data);
|
||||
explicit array(std::initializer_list<float> data);
|
||||
|
||||
/* Special case so array({}, type) is an empty array. */
|
||||
array(std::initializer_list<int> data, Dtype dtype);
|
||||
explicit array(std::initializer_list<int> data, Dtype dtype);
|
||||
|
||||
template <typename T>
|
||||
array(
|
||||
explicit array(
|
||||
std::initializer_list<T> data,
|
||||
Shape shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
explicit array(
|
||||
allocator::Buffer data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
@@ -498,7 +499,7 @@ class array {
|
||||
|
||||
template <typename T>
|
||||
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
init(&val);
|
||||
}
|
||||
|
||||
@@ -516,7 +517,7 @@ array::array(
|
||||
std::initializer_list<T> data,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
@@ -32,6 +32,7 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BroadcastAxes)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
@@ -43,6 +44,7 @@ DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
@@ -65,7 +67,6 @@ DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
@@ -76,6 +77,7 @@ DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
|
@@ -5,13 +5,21 @@ else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(SHELL_EXT ps1)
|
||||
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
|
||||
else()
|
||||
set(SHELL_EXT sh)
|
||||
set(SHELL_CMD /bin/bash)
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
|
||||
DEPENDS make_compiled_preamble.${SHELL_EXT}
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
@@ -58,5 +66,6 @@ target_sources(
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
|
||||
endif()
|
||||
|
@@ -28,8 +28,8 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
} else if (b.data_size() == 1 && a.flags().contiguous) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
} else if (
|
||||
a.flags().row_contiguous && b.flags().row_contiguous ||
|
||||
a.flags().col_contiguous && b.flags().col_contiguous) {
|
||||
(a.flags().row_contiguous && b.flags().row_contiguous) ||
|
||||
(a.flags().col_contiguous && b.flags().col_contiguous)) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
} else {
|
||||
bopt = BinaryOpType::General;
|
||||
|
@@ -42,9 +42,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
return move_or_copy(in, out, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
@@ -61,6 +59,14 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
move_or_copy(in, out, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
@@ -85,6 +91,16 @@ void Depends::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
@@ -141,9 +157,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, Strides> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
@@ -180,7 +194,7 @@ std::pair<bool, Strides> Reshape::prepare_reshape(
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out) {
|
||||
@@ -248,6 +262,20 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
Strides strides;
|
||||
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||
if (j < axes_.size() && i == axes_[j]) {
|
||||
j++;
|
||||
} else {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
}
|
||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
move_or_copy(inputs[0], out);
|
||||
|
@@ -130,7 +130,7 @@ std::string build_lib_name(
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape) {
|
||||
const Shape& shape) {
|
||||
bool contiguous = true;
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
|
@@ -11,9 +11,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
@@ -56,7 +54,7 @@ inline bool is_scalar(const array& x) {
|
||||
// Check if we can use a contiguous operation given inputs and the output shape
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape);
|
||||
const Shape& shape);
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
void compiled_allocate_outputs(
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include <format>
|
||||
#include <fstream>
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
@@ -9,6 +10,7 @@
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/compiled_preamble.h"
|
||||
#include "mlx/backend/common/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
|
||||
@@ -44,11 +46,8 @@ namespace detail {
|
||||
bool compile_available_for_device(const Device& device) {
|
||||
return true;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::string get_temp_file(const std::string& name) {
|
||||
return std::filesystem::temp_directory_path().append(name).string();
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Return a pointer to a compiled function
|
||||
void* compile(
|
||||
@@ -68,24 +67,30 @@ void* compile(
|
||||
std::string source_code = source_builder();
|
||||
std::string kernel_file_name;
|
||||
|
||||
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||
// characters. Clip file name with a little extra room and append a 16
|
||||
// character hash.
|
||||
// Deal with long kernel names. Maximum length for filename on macOS is 255
|
||||
// characters, and on Windows the maximum length for whole path is 260. Clip
|
||||
// file name with a little extra room and append a 16 character hash.
|
||||
#ifdef _WIN32
|
||||
constexpr int max_file_name_length = 140;
|
||||
#else
|
||||
constexpr int max_file_name_length = 245;
|
||||
#endif
|
||||
if (kernel_name.size() > max_file_name_length) {
|
||||
std::ostringstream file_name;
|
||||
file_name
|
||||
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
|
||||
auto file_id = std::hash<std::string>{}(kernel_name);
|
||||
auto file_id =
|
||||
std::hash<std::string>{}(kernel_name.substr(max_file_name_length - 16));
|
||||
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
|
||||
kernel_file_name = file_name.str();
|
||||
} else {
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
std::ostringstream shared_lib_name;
|
||||
shared_lib_name << "lib" << kernel_file_name << ".so";
|
||||
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
||||
auto output_dir = std::filesystem::temp_directory_path();
|
||||
|
||||
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
||||
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
||||
bool lib_exists = false;
|
||||
{
|
||||
std::ifstream f(shared_lib_path.c_str());
|
||||
@@ -94,24 +99,21 @@ void* compile(
|
||||
|
||||
if (!lib_exists) {
|
||||
// Open source file and write source code to it
|
||||
std::ostringstream source_file_name;
|
||||
source_file_name << kernel_file_name << ".cpp";
|
||||
auto source_file_path = get_temp_file(source_file_name.str());
|
||||
std::string source_file_name = kernel_file_name + ".cpp";
|
||||
auto source_file_path = (output_dir / source_file_name).string();
|
||||
|
||||
std::ofstream source_file(source_file_path);
|
||||
source_file << source_code;
|
||||
source_file.close();
|
||||
|
||||
std::ostringstream build_command;
|
||||
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
|
||||
<< source_file_path << "' -o '" << shared_lib_path << "'";
|
||||
std::string build_command_str = build_command.str();
|
||||
auto return_code = system(build_command_str.c_str());
|
||||
if (return_code) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
|
||||
<< " with error code " << return_code << "." << std::endl;
|
||||
throw std::runtime_error(msg.str());
|
||||
try {
|
||||
JitCompiler::exec(JitCompiler::build_command(
|
||||
output_dir, source_file_name, shared_lib_name));
|
||||
} catch (const std::exception& error) {
|
||||
throw std::runtime_error(std::format(
|
||||
"[Compile::eval_cpu] Failed to compile function {0}: {1}",
|
||||
kernel_name,
|
||||
error.what()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,6 +153,11 @@ inline void build_kernel(
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// Export the symbol
|
||||
os << "__declspec(dllexport) ";
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
|
||||
|
@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
|
||||
Shape padded_shape = {N, iH + 2 * padding[0], C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N * oH, wH * C};
|
||||
Shape strided_reshape = {N * oH, wH * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
auto conv_dtype = out.dtype();
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {
|
||||
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
||||
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
|
||||
Shape strided_reshape = {N * oH * oW, wH * wW * C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy(in_strided_view, in_strided, CopyType::General);
|
||||
|
||||
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
|
||||
const std::vector<int>& wt_dilation,
|
||||
const bool flip) {
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const auto iDim = std::vector<int>(
|
||||
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
|
||||
const auto oDim = std::vector<int>(
|
||||
const auto iDim =
|
||||
Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
|
||||
const auto oDim = Shape(
|
||||
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(-1); // In channels
|
||||
const auto wDim = std::vector<int>(
|
||||
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
|
||||
const auto wDim =
|
||||
Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
|
||||
|
||||
auto conv_dtype = float32;
|
||||
|
||||
// Pad input
|
||||
std::vector<int> padded_shape(in.shape().size());
|
||||
Shape padded_shape(in.shape().size());
|
||||
padded_shape.front() = N;
|
||||
for (size_t i = 0; i < iDim.size(); i++) {
|
||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
||||
|
@@ -37,6 +37,7 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BroadcastAxes)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
@@ -57,6 +58,7 @@ DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
@@ -86,7 +88,6 @@ DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
@@ -101,6 +102,7 @@ DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
|
153
mlx/backend/common/jit_compiler.cpp
Normal file
153
mlx/backend/common/jit_compiler.cpp
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/jit_compiler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include <format>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
||||
namespace {
|
||||
|
||||
// Split string into array.
|
||||
std::vector<std::string> str_split(const std::string& str, char delimiter) {
|
||||
std::vector<std::string> tokens;
|
||||
std::string token;
|
||||
std::istringstream tokenStream(str);
|
||||
while (std::getline(tokenStream, token, delimiter)) {
|
||||
tokens.push_back(token);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
// Get path information about MSVC.
|
||||
struct VisualStudioInfo {
|
||||
VisualStudioInfo() {
|
||||
#ifdef _M_ARM64
|
||||
arch = "arm64";
|
||||
#else
|
||||
arch = "x64";
|
||||
#endif
|
||||
// Get path of Visual Studio.
|
||||
std::string vs_path = JitCompiler::exec(std::format(
|
||||
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
|
||||
" -property installationPath",
|
||||
std::getenv("ProgramFiles(x86)")));
|
||||
if (vs_path.empty()) {
|
||||
throw std::runtime_error("Can not find Visual Studio.");
|
||||
}
|
||||
// Read the envs from vcvarsall.
|
||||
std::string envs = JitCompiler::exec(std::format(
|
||||
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
|
||||
vs_path,
|
||||
arch));
|
||||
for (const std::string& line : str_split(envs, '\n')) {
|
||||
// Each line is in the format "ENV_NAME=values".
|
||||
auto pos = line.find_first_of('=');
|
||||
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
|
||||
continue;
|
||||
std::string name = line.substr(0, pos);
|
||||
std::string value = line.substr(pos + 1);
|
||||
if (name == "LIB") {
|
||||
libpaths = str_split(value, ';');
|
||||
} else if (name == "VCToolsInstallDir") {
|
||||
cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::string arch;
|
||||
std::string cl_exe;
|
||||
std::vector<std::string> libpaths;
|
||||
};
|
||||
|
||||
const VisualStudioInfo& GetVisualStudioInfo() {
|
||||
static VisualStudioInfo info;
|
||||
return info;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // _MSC_VER
|
||||
|
||||
std::string JitCompiler::build_command(
|
||||
const std::filesystem::path& dir,
|
||||
const std::string& source_file_name,
|
||||
const std::string& shared_lib_name) {
|
||||
#ifdef _MSC_VER
|
||||
const VisualStudioInfo& info = GetVisualStudioInfo();
|
||||
std::string libpaths;
|
||||
for (const std::string& lib : info.libpaths) {
|
||||
libpaths += std::format(" /libpath:\"{0}\"", lib);
|
||||
}
|
||||
return std::format(
|
||||
"\""
|
||||
"cd /D \"{0}\" && "
|
||||
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
|
||||
"/link /out:\"{3}\" {4} 2>&1"
|
||||
"\"",
|
||||
dir.string(),
|
||||
info.cl_exe,
|
||||
source_file_name,
|
||||
shared_lib_name,
|
||||
libpaths);
|
||||
#else
|
||||
return std::format(
|
||||
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
|
||||
(dir / source_file_name).string(),
|
||||
(dir / shared_lib_name).string());
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string JitCompiler::exec(const std::string& cmd) {
|
||||
#ifdef _MSC_VER
|
||||
FILE* pipe = _popen(cmd.c_str(), "r");
|
||||
#else
|
||||
FILE* pipe = popen(cmd.c_str(), "r");
|
||||
#endif
|
||||
if (!pipe) {
|
||||
throw std::runtime_error("popen() failed.");
|
||||
}
|
||||
char buffer[128];
|
||||
std::string ret;
|
||||
while (fgets(buffer, sizeof(buffer), pipe)) {
|
||||
ret += buffer;
|
||||
}
|
||||
// Trim trailing spaces.
|
||||
ret.erase(
|
||||
std::find_if(
|
||||
ret.rbegin(),
|
||||
ret.rend(),
|
||||
[](unsigned char ch) { return !std::isspace(ch); })
|
||||
.base(),
|
||||
ret.end());
|
||||
|
||||
#ifdef _MSC_VER
|
||||
int status = _pclose(pipe);
|
||||
#else
|
||||
int status = pclose(pipe);
|
||||
#endif
|
||||
if (status == -1) {
|
||||
throw std::runtime_error("pclose() failed.");
|
||||
}
|
||||
#ifdef _MSC_VER
|
||||
int code = status;
|
||||
#else
|
||||
int code = WEXITSTATUS(status);
|
||||
#endif
|
||||
if (code != 0) {
|
||||
throw std::runtime_error(std::format(
|
||||
"Failed to execute command with return code {0}: \"{1}\", "
|
||||
"the output is: {2}",
|
||||
code,
|
||||
cmd,
|
||||
ret));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
20
mlx/backend/common/jit_compiler.h
Normal file
20
mlx/backend/common/jit_compiler.h
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
class JitCompiler {
|
||||
public:
|
||||
// Build a shell command that compiles a source code file to a shared library.
|
||||
static std::string build_command(
|
||||
const std::filesystem::path& dir,
|
||||
const std::string& source_file_name,
|
||||
const std::string& shared_lib_name);
|
||||
|
||||
// Run a command and get its output.
|
||||
static std::string exec(const std::string& cmd);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
38
mlx/backend/common/make_compiled_preamble.ps1
Normal file
38
mlx/backend/common/make_compiled_preamble.ps1
Normal file
@@ -0,0 +1,38 @@
|
||||
# This script generates a C++ function that provides the CPU
|
||||
# code for use with kernel generation.
|
||||
#
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
$OUTPUT_FILE = $args[0]
|
||||
$CL = $args[1]
|
||||
$SRCDIR = $args[2]
|
||||
|
||||
# Get command result as array.
|
||||
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
|
||||
# Remove empty lines.
|
||||
# Otherwise there will be too much empty lines making the result unreadable.
|
||||
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
|
||||
# Concatenate to string.
|
||||
$CONTENT = $CONTENT -join "`n"
|
||||
|
||||
# Append extra content.
|
||||
$CONTENT = @"
|
||||
$($CONTENT)
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::detail;
|
||||
"@
|
||||
|
||||
# Convert each char to ASCII code.
|
||||
# Unlike the unix script that outputs string literal directly, the output from
|
||||
# MSVC is way too large to be embedded as string and compilation will fail, so
|
||||
# we store it as static array instead.
|
||||
$CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
|
||||
|
||||
$OUTPUT = @"
|
||||
const char* get_kernel_preamble() {
|
||||
static char preamble[] = { $CHARCODES };
|
||||
return preamble;
|
||||
}
|
||||
"@
|
||||
|
||||
Set-Content -Path $OUTPUT_FILE -Value $OUTPUT
|
@@ -10,15 +10,16 @@ OUTPUT_FILE=$1
|
||||
GCC=$2
|
||||
SRCDIR=$3
|
||||
CLANG=$4
|
||||
ARCH=$5
|
||||
|
||||
if [ "$CLANG" = "TRUE" ]; then
|
||||
read -r -d '' INCLUDES <<- EOM
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
EOM
|
||||
CC_FLAGS=""
|
||||
CC_FLAGS="-arch ${ARCH}"
|
||||
else
|
||||
CC_FLAGS="-std=c++17"
|
||||
fi
|
||||
|
@@ -19,6 +19,45 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& axes) {
|
||||
auto compute_offset = [&strides, &axes](const auto* indices) {
|
||||
int64_t offset = 0;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
offset += indices[i] * strides[axes[i]];
|
||||
}
|
||||
return offset;
|
||||
};
|
||||
switch (indices.dtype()) {
|
||||
case int8:
|
||||
case uint8:
|
||||
return compute_offset(indices.data<uint8_t>());
|
||||
case int16:
|
||||
case uint16:
|
||||
return compute_offset(indices.data<uint16_t>());
|
||||
case int32:
|
||||
case uint32:
|
||||
return compute_offset(indices.data<uint32_t>());
|
||||
case int64:
|
||||
case uint64:
|
||||
return compute_offset(indices.data<uint64_t>());
|
||||
default:
|
||||
throw std::runtime_error("Invalid indices type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -258,6 +297,14 @@ void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -417,18 +464,8 @@ void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out);
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -499,34 +536,64 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto copy_needed = std::any_of(
|
||||
strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
Strides ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
|
||||
copy_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
/* const Strides& i_strides = */ in.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ i_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void DynamicSliceUpdate::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
|
||||
// Copy or move src to dst
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
|
||||
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
|
||||
/* const std::vector<stride_t>& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ o_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -554,12 +621,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ upd_strides,
|
||||
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
|
||||
/* const std::vector<stride_t>& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
@@ -615,7 +681,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
|
@@ -14,10 +14,10 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename IdxT = int32_t>
|
||||
template <typename T>
|
||||
struct StridedIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using difference_type = IdxT;
|
||||
using difference_type = int32_t;
|
||||
using value_type = T;
|
||||
using reference = value_type&;
|
||||
using pointer = value_type*;
|
||||
|
@@ -67,7 +67,12 @@ void set_ternary_op_output_data(
|
||||
}
|
||||
break;
|
||||
case TernaryOpType::General:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
// Try to donate an input which is row_contiguous
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -107,7 +107,7 @@ struct ContiguousIterator {
|
||||
: shape_(a.shape()), strides_(a.strides()) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
pos_ = Shape(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,4 +168,10 @@ void move_or_copy(
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
||||
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
} // namespace mlx::core
|
||||
|
@@ -34,16 +34,20 @@ BufferCache::~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
void BufferCache::clear() {
|
||||
int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf)
|
||||
if (holder->buf) {
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
||||
@@ -81,10 +85,11 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
||||
}
|
||||
}
|
||||
|
||||
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
clear();
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
@@ -92,10 +97,12 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
}
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,11 +151,11 @@ MetalAllocator::MetalAllocator()
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
||||
block_limit_ =
|
||||
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
|
||||
gc_limit_ = std::min(
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
|
||||
block_limit_);
|
||||
auto max_rec_size =
|
||||
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
|
||||
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
|
||||
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
|
||||
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
|
||||
max_pool_size_ = block_limit_;
|
||||
device(mlx::core::Device::gpu)
|
||||
.set_residency_set(residency_set_.mtl_residency_set());
|
||||
@@ -186,7 +193,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// More helpful message if maximum buffer length is exceeded
|
||||
if (size > device_->maxBufferLength()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Attempting to allocate " << size << " bytes which is greater than"
|
||||
msg << "[metal::malloc] Attempting to allocate " << size
|
||||
<< " bytes which is greater than"
|
||||
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
|
||||
<< " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
@@ -212,16 +220,26 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache
|
||||
if (mem_required >= gc_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
||||
num_resources_ -=
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
}
|
||||
|
||||
// Allocate new buffer if needed
|
||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
||||
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
||||
if (num_resources_ >= resource_limit_) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::malloc] Resource limit (" << resource_limit_
|
||||
<< ") exceeded.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
lk.unlock();
|
||||
buf = device_->newBuffer(size, res_opt);
|
||||
lk.lock();
|
||||
if (buf) {
|
||||
num_resources_++;
|
||||
}
|
||||
}
|
||||
|
||||
active_memory_ += buf->length();
|
||||
@@ -230,7 +248,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// Maintain the cache below the requested limit
|
||||
if (get_cache_memory() >= max_pool_size_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
num_resources_ -= buffer_cache_.release_cached_buffers(
|
||||
get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
residency_set_.insert(buf);
|
||||
@@ -241,7 +260,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
void MetalAllocator::clear_cache() {
|
||||
std::unique_lock lk(mutex_);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.clear();
|
||||
num_resources_ -= buffer_cache_.clear();
|
||||
}
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
@@ -255,6 +274,7 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
num_resources_--;
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
|
@@ -23,11 +23,11 @@ class BufferCache {
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
void recycle_to_cache(MTL::Buffer* buf);
|
||||
void release_cached_buffers(size_t min_bytes_to_free);
|
||||
int release_cached_buffers(size_t min_bytes_to_free);
|
||||
size_t cache_size() {
|
||||
return pool_size_;
|
||||
}
|
||||
void clear();
|
||||
int clear();
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
@@ -94,6 +94,8 @@ class MetalAllocator : public allocator::Allocator {
|
||||
size_t max_pool_size_;
|
||||
size_t wired_limit_{0};
|
||||
bool relaxed_{true};
|
||||
size_t num_resources_{0};
|
||||
size_t resource_limit_{0};
|
||||
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
@@ -81,13 +81,15 @@ void binary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
bool large;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (bopt == BinaryOpType::General) {
|
||||
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
|
||||
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.size() > INT32_MAX;
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
}
|
||||
std::string kernel_name =
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <fmt/format.h>
|
||||
#include <iostream> //TODO
|
||||
#include <format>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -12,8 +11,6 @@
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void build_kernel(
|
||||
@@ -42,7 +39,7 @@ inline void build_kernel(
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
|
||||
|
||||
// Add the input arguments
|
||||
@@ -58,7 +55,7 @@ inline void build_kernel(
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
add_indices = true;
|
||||
}
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" device const {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
xname,
|
||||
@@ -66,13 +63,13 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" device {0}* {1} [[buffer({2})]],\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x),
|
||||
@@ -80,13 +77,13 @@ inline void build_kernel(
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
os += std::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
@@ -99,15 +96,15 @@ inline void build_kernel(
|
||||
// a third grid dimension
|
||||
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
|
||||
} else if (work_per_thread > 1) {
|
||||
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += fmt::format(
|
||||
os += std::format(" constexpr int N_ = {0};\n", work_per_thread);
|
||||
os += std::format(
|
||||
" int xshape = output_shape[{0}];\n",
|
||||
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
|
||||
idx_type);
|
||||
}
|
||||
@@ -122,16 +119,16 @@ inline void build_kernel(
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" auto tmp_{0} = static_cast<{1}>({2});\n",
|
||||
xname,
|
||||
get_type_string(x.dtype()),
|
||||
ss.str());
|
||||
} else if (is_scalar(x)) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
|
||||
} else if (contiguous) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
|
||||
} else {
|
||||
nc_inputs.push_back(x);
|
||||
@@ -141,30 +138,30 @@ inline void build_kernel(
|
||||
// Initialize the indices for non-contiguous inputs
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
|
||||
os += std::format(" {0} index_{1} = ", idx_type, xname);
|
||||
if (ndim == 1) {
|
||||
int offset = i * ndim;
|
||||
os +=
|
||||
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
std::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
|
||||
} else if (ndim == 2) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
|
||||
idx_type,
|
||||
offset);
|
||||
} else if (ndim == 3) {
|
||||
int offset = i * ndim;
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = (i + 1) * ndim;
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
|
||||
idx_type,
|
||||
offset - 1,
|
||||
offset - 2);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
|
||||
idx_type,
|
||||
i);
|
||||
@@ -176,18 +173,18 @@ inline void build_kernel(
|
||||
if (dynamic_dims) {
|
||||
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
|
||||
} else {
|
||||
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
os += std::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
|
||||
}
|
||||
os += " uint l = zpos % output_shape[d];\n";
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& xname = namer.get_name(nc_inputs[i]);
|
||||
os += fmt::format(" index_{0} += ", xname);
|
||||
os += std::format(" index_{0} += ", xname);
|
||||
if (dynamic_dims) {
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
||||
std::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
|
||||
} else {
|
||||
os +=
|
||||
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
||||
std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
|
||||
}
|
||||
}
|
||||
os += " zpos /= output_shape[d];\n }\n";
|
||||
@@ -203,16 +200,16 @@ inline void build_kernel(
|
||||
for (int i = 0; i < nc_inputs.size(); ++i) {
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
"static_cast<{0}>(tmp_{1});\n",
|
||||
get_type_string(x.dtype()),
|
||||
namer.get_name(x.inputs()[0]));
|
||||
@@ -222,15 +219,15 @@ inline void build_kernel(
|
||||
os += ss.str();
|
||||
os += "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
os += std::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
// Increment indices and close per thread loop
|
||||
if (work_per_thread > 1) {
|
||||
@@ -238,10 +235,10 @@ inline void build_kernel(
|
||||
auto& x = nc_inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
if (!dynamic_dims) {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
|
||||
} else {
|
||||
os += fmt::format(
|
||||
os += std::format(
|
||||
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
|
||||
}
|
||||
}
|
||||
|
@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
|
||||
int implicit_K = wt.size() / conv_params.O;
|
||||
int implicit_N = conv_params.O;
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K};
|
||||
Shape unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
}
|
||||
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
|
||||
Shape unfolded_shape{implicit_M, implicit_K * groups};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
|
||||
@@ -192,12 +192,12 @@ void conv_1D_gpu(
|
||||
bool flip) {
|
||||
// Make conv params
|
||||
MLXConvParams<1> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(2),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1)},
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in.shape(2)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
|
||||
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
|
||||
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
|
||||
/* const int str[NDIM] = */ {wt_strides[0]},
|
||||
/* const int pad[NDIM] = */ {padding[0]},
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0]},
|
||||
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params,
|
||||
std::vector<array>& copies_w) {
|
||||
std::vector<int> padded_shape = {
|
||||
Shape padded_shape = {
|
||||
conv_params.N,
|
||||
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
|
||||
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
|
||||
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
|
||||
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
array zero_arr = array(0, in.dtype());
|
||||
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
|
||||
copies_w.push_back(in_padded);
|
||||
|
||||
MLXConvParams<2> conv_params_updated{
|
||||
/* const int N = */ in_padded.shape(0),
|
||||
/* const int C = */ in_padded.shape(3),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
|
||||
/* const int N = */ static_cast<int>(in_padded.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in_padded.shape(3)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */
|
||||
{static_cast<int>(in_padded.shape(1)),
|
||||
static_cast<int>(in_padded.shape(2))},
|
||||
/* const int wS[NDIM] = */
|
||||
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
|
||||
/* const int oS[NDIM] = */
|
||||
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
|
||||
/* const int str[NDIM] = */ {1, 1},
|
||||
/* const int pad[NDIM] = */ {0, 0},
|
||||
/* const int kdil[NDIM] = */ {1, 1},
|
||||
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
|
||||
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
|
||||
|
||||
// Do filter transform
|
||||
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
|
||||
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
|
||||
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
|
||||
copies_w.push_back(filt_wg);
|
||||
{
|
||||
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
|
||||
}
|
||||
|
||||
// Do input transform
|
||||
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
|
||||
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
|
||||
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
|
||||
copies_w.push_back(inp_wg);
|
||||
{
|
||||
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
|
||||
}
|
||||
|
||||
// Do batched gemm
|
||||
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
|
||||
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
|
||||
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
|
||||
copies_w.push_back(out_wg);
|
||||
{
|
||||
@@ -723,12 +727,15 @@ void conv_2D_gpu(
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
MLXConvParams<2> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(3),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in.shape(3)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */
|
||||
{static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
|
||||
/* const int wS[NDIM] = */
|
||||
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
|
||||
/* const int oS[NDIM] = */
|
||||
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
|
||||
/* const int pad[NDIM] = */ {padding[0], padding[1]},
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||
@@ -800,12 +807,21 @@ void conv_3D_gpu(
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
MLXConvParams<3> conv_params{
|
||||
/* const int N = */ in.shape(0),
|
||||
/* const int C = */ in.shape(4),
|
||||
/* const int O = */ wt.shape(0),
|
||||
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
|
||||
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
|
||||
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
|
||||
/* const int N = */ static_cast<int>(in.shape(0)),
|
||||
/* const int C = */ static_cast<int>(in.shape(4)),
|
||||
/* const int O = */ static_cast<int>(wt.shape(0)),
|
||||
/* const int iS[NDIM] = */
|
||||
{static_cast<int>(in.shape(1)),
|
||||
static_cast<int>(in.shape(2)),
|
||||
static_cast<int>(in.shape(3))},
|
||||
/* const int wS[NDIM] = */
|
||||
{static_cast<int>(wt.shape(1)),
|
||||
static_cast<int>(wt.shape(2)),
|
||||
static_cast<int>(wt.shape(3))},
|
||||
/* const int oS[NDIM] = */
|
||||
{static_cast<int>(out.shape(1)),
|
||||
static_cast<int>(out.shape(2)),
|
||||
static_cast<int>(out.shape(3))},
|
||||
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
|
||||
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
|
||||
/* const int kdil[NDIM] = */
|
||||
|
@@ -52,7 +52,9 @@ void copy_gpu_inplace(
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -80,6 +82,7 @@ void copy_gpu_inplace(
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
}
|
||||
bool dynamic = dynamic_i_offset || dynamic_o_offset;
|
||||
auto& d = metal::device(s.device);
|
||||
int work_per_thread = 1;
|
||||
std::string kernel_name;
|
||||
@@ -107,9 +110,17 @@ void copy_gpu_inplace(
|
||||
if (large) {
|
||||
kernel_name += "large";
|
||||
}
|
||||
if (dynamic) {
|
||||
kernel_name += "_dynamic";
|
||||
if (ctype != CopyType::GeneralGeneral) {
|
||||
throw std::runtime_error(
|
||||
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
|
||||
}
|
||||
}
|
||||
}
|
||||
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
|
||||
auto kernel = get_copy_kernel(d, kernel_name, in, out);
|
||||
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
|
||||
: get_copy_kernel(d, kernel_name, in, out);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@@ -145,6 +156,18 @@ void copy_gpu_inplace(
|
||||
compute_encoder.set_bytes(ndim, 5);
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
}
|
||||
if (dynamic) {
|
||||
if (dynamic_i_offset) {
|
||||
compute_encoder.set_input_array(*dynamic_i_offset, 6);
|
||||
} else {
|
||||
compute_encoder.set_bytes(0ll, 6);
|
||||
}
|
||||
if (dynamic_o_offset) {
|
||||
compute_encoder.set_input_array(*dynamic_o_offset, 7);
|
||||
} else {
|
||||
compute_encoder.set_bytes(0ll, 7);
|
||||
}
|
||||
}
|
||||
|
||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||
if (thread_group_size != 1024) {
|
||||
@@ -179,13 +202,13 @@ void copy_gpu_inplace(
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& istride,
|
||||
int64_t ioffset,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
|
||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||
}
|
||||
|
||||
void fill_gpu(const array& val, array& out, const Stream& s) {
|
||||
|
@@ -17,13 +17,15 @@ void copy_gpu_inplace(
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype);
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& src,
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
@@ -31,8 +33,8 @@ void copy_gpu_inplace(
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Strides& istride,
|
||||
int64_t ioffset,
|
||||
const Strides& i_strides,
|
||||
int64_t i_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
|
@@ -651,18 +651,23 @@ device_info() {
|
||||
auto raw_device = device(default_device()).mtl_device();
|
||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||
|
||||
int mib[] = {CTL_HW, HW_MEMSIZE};
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
|
||||
sysctl(mib, 2, &memsize, &length, NULL, 0);
|
||||
size_t rsrc_limit = 0;
|
||||
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||
if (rsrc_limit == 0) {
|
||||
rsrc_limit = 499000;
|
||||
}
|
||||
|
||||
return {
|
||||
{"architecture", arch},
|
||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||
{"max_recommended_working_set_size",
|
||||
raw_device->recommendedMaxWorkingSetSize()},
|
||||
{"memory_size", memsize}};
|
||||
{"memory_size", memsize},
|
||||
{"resource_limit", rsrc_limit}};
|
||||
};
|
||||
static auto device_info_ = init_device_info();
|
||||
return device_info_;
|
||||
|
@@ -3,25 +3,20 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const array& in, const array& out, const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
|
||||
if (e_signal.valid()) {
|
||||
encode_signal(e_signal);
|
||||
}
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
encode_wait(e_wait);
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
@@ -38,8 +33,12 @@ void AllReduce::eval_gpu(
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
signal_and_wait(in.event(), e);
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
e = std::move(e),
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
@@ -53,11 +52,9 @@ void AllReduce::eval_gpu(
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
out.event().signal();
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
@@ -70,15 +67,19 @@ void AllGather::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto task = [in = in, out = out, group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
out.event().signal();
|
||||
};
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
signal_and_wait(in.event(), e);
|
||||
|
||||
auto task =
|
||||
[in = in, out = out, e = std::move(e), group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
signal_and_wait(in, out, stream());
|
||||
}
|
||||
|
||||
void Send::eval_gpu(
|
||||
@@ -89,27 +90,20 @@ void Send::eval_gpu(
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
move_or_copy(in, out);
|
||||
|
||||
// Schedule an async send on the comm stream
|
||||
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::send(group, in, dst);
|
||||
out.event().signal();
|
||||
distributed::detail::send(group, out, dst);
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
// Encode a signal event for the input but not a wait since we don't need to
|
||||
// wait on the output.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
d.end_encoding(s.index);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
// Encode a signal event for the input
|
||||
if (in.event().valid()) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(in.event().raw_event().get()),
|
||||
in.event().value());
|
||||
encode_signal(in.event());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,20 +117,18 @@ void Recv::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task = [out = out, group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
|
||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
encode_wait(e);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task =
|
||||
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -6,6 +6,26 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void encode_wait(Event e) {
|
||||
auto& d = metal::device(e.stream().device);
|
||||
d.end_encoding(e.stream().index);
|
||||
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||
command_buffer->addCompletedHandler(
|
||||
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void encode_signal(Event e) {
|
||||
auto& d = metal::device(e.stream().device);
|
||||
d.end_encoding(e.stream().index);
|
||||
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||
command_buffer->addCompletedHandler(
|
||||
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
Event::Event(const Stream& stream) : stream_(stream) {
|
||||
auto dtor = [](void* ptr) {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
|
10
mlx/backend/metal/event.h
Normal file
10
mlx/backend/metal/event.h
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void encode_wait(Event e);
|
||||
|
||||
void encode_signal(Event e);
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,5 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <fmt/format.h>
|
||||
#include <format>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
@@ -20,9 +20,9 @@ std::pair<std::string, std::string> make_index_args(
|
||||
std::ostringstream idx_args;
|
||||
std::ostringstream idx_arr;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_args << fmt::format(
|
||||
idx_args << std::format(
|
||||
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
|
||||
idx_arr << fmt::format("idx{0}", i);
|
||||
idx_arr << std::format("idx{0}", i);
|
||||
if (i < nidx - 1) {
|
||||
idx_args << "\n";
|
||||
idx_arr << ",";
|
||||
@@ -53,19 +53,19 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
|
||||
bool large_src = src.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_index = nidx && inputs[1].size() > INT32_MAX;
|
||||
bool large_src = src.size() > INT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large = large_index || large_src || large_out;
|
||||
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
std::string kernel_name = fmt::format(
|
||||
std::string kernel_name = std::format(
|
||||
"gather{0}{1}_{2}_{3}_{4}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
@@ -77,7 +77,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
// Index dimension specializations
|
||||
kernel_source += fmt::format(
|
||||
kernel_source += std::format(
|
||||
gather_kernels,
|
||||
type_to_name(out) + idx_type_name,
|
||||
out_type_str,
|
||||
@@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
@@ -234,11 +234,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
|
||||
bool large_upd = upd.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
|
||||
bool large_upd = upd.size() > INT32_MAX;
|
||||
bool large = large_out || large_idx || large_upd;
|
||||
std::string kernel_name = fmt::format(
|
||||
std::string kernel_name = std::format(
|
||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||
type_to_name(out),
|
||||
idx_type_name,
|
||||
@@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nidx,
|
||||
upd_contig ? "updc_true" : "updc_false",
|
||||
nwork,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
@@ -275,11 +275,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
if (reduce_type_ != Scatter::None) {
|
||||
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
|
||||
op_type = std::vformat(op_type, std::make_format_args(out_type_str));
|
||||
}
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
kernel_source += fmt::format(
|
||||
kernel_source += std::format(
|
||||
scatter_kernels,
|
||||
type_to_name(out) + idx_type_name + "_" + op_name,
|
||||
out_type_str,
|
||||
@@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
|
@@ -1,25 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gemv_masked_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
|
||||
const device {itype}* mat [[buffer(0)]],
|
||||
const device {itype}* in_vec [[buffer(1)]],
|
||||
device {itype}* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device {outm_t}* out_mask [[buffer(20)]],
|
||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
@@ -1,32 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view steel_conv_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_conv_general_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
@@ -1,106 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view steel_gemm_fused_kernels = R"(
|
||||
template [[host_name("{name}")]]
|
||||
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
|
||||
const device {itype} *A [[buffer(0)]],
|
||||
const device {itype} *B [[buffer(1)]],
|
||||
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
|
||||
device {itype} *D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_masked_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
block_masked_gemm<
|
||||
{itype},
|
||||
{outmasktype},
|
||||
{opmasktype},
|
||||
{bm},
|
||||
{bn},
|
||||
{bk},
|
||||
{wm},
|
||||
{wn},
|
||||
{trans_a},
|
||||
{trans_b},
|
||||
{mn_aligned},
|
||||
{k_aligned}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant int64_t* batch_strides [[buffer(7)]],
|
||||
const device {outmasktype}* out_mask [[buffer(10)]],
|
||||
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
||||
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
||||
const constant int* mask_strides [[buffer(13)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_splitk_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemm_splitk<
|
||||
{itype},
|
||||
{otype},
|
||||
{bm},
|
||||
{bn},
|
||||
{bk},
|
||||
{wm},
|
||||
{wn},
|
||||
{trans_a},
|
||||
{trans_b},
|
||||
{mn_aligned},
|
||||
{k_aligned}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {otype}* C [[buffer(2)]],
|
||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemm_splitk_accum<{atype}, {otype}>(
|
||||
const device {atype}* C_split [[buffer(0)]],
|
||||
device {otype}* D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemm_splitk_accum_axpby<{atype}, {otype}>(
|
||||
const device {atype}* C_split [[buffer(0)]],
|
||||
device {otype}* D [[buffer(1)]],
|
||||
const constant int& k_partitions [[buffer(2)]],
|
||||
const constant int& partition_stride [[buffer(3)]],
|
||||
const constant int& ldd [[buffer(4)]],
|
||||
const device {otype}* C [[buffer(5)]],
|
||||
const constant int& ldc [[buffer(6)]],
|
||||
const constant int& fdc [[buffer(7)]],
|
||||
const constant float& alpha [[buffer(8)]],
|
||||
const constant float& beta [[buffer(9)]],
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
)";
|
@@ -1,16 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string op_name(const array& arr) {
|
||||
@@ -26,7 +21,7 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
auto lib = d.get_library(kernel_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::arange()
|
||||
<< fmt::format(
|
||||
<< std::format(
|
||||
arange_kernels,
|
||||
kernel_name,
|
||||
get_type_string(out.dtype()));
|
||||
@@ -52,7 +47,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source;
|
||||
@@ -74,7 +69,7 @@ void append_binary_kernels(
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g1large", "binary_g_nd1"},
|
||||
{"g2large", "binary_g_nd2"},
|
||||
{"g3large", "binary_g_nd3"},
|
||||
}};
|
||||
@@ -86,11 +81,13 @@ void append_binary_kernels(
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
|
||||
}
|
||||
@@ -141,7 +138,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
}};
|
||||
@@ -150,11 +147,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
@@ -178,7 +177,7 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@@ -186,19 +185,23 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
@@ -210,6 +213,38 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_dynamic_copy_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::string kernel_source = metal::utils();
|
||||
kernel_source += metal::copy();
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4);
|
||||
return kernel_source;
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_softmax_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -219,7 +254,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
|
||||
auto lib = d.get_library(lib_name, [&] {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::softmax()
|
||||
<< fmt::format(
|
||||
<< std::format(
|
||||
softmax_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
@@ -405,17 +440,17 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_fused()
|
||||
<< fmt::format(
|
||||
steel_gemm_fused_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b);
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
"gemm",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
@@ -440,20 +475,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
steel_gemm_splitk_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
"gemm_splitk",
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
mn_aligned,
|
||||
k_aligned);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -470,13 +505,12 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
fmt::runtime(
|
||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
||||
: steel_gemm_splitk_accum_kernels),
|
||||
"name"_a = lib_name,
|
||||
"atype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()));
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
axbpy ? "gemm_splitk_accum_axpby"
|
||||
: "gemm_splitk_accum",
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()));
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -507,21 +541,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_masked()
|
||||
<< fmt::format(
|
||||
steel_gemm_masked_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"outmasktype"_a = out_mask_type,
|
||||
"opmasktype"_a = op_mask_type,
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
"block_masked_gemm",
|
||||
get_type_string(out.dtype()),
|
||||
out_mask_type,
|
||||
op_mask_type,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
mn_aligned,
|
||||
k_aligned);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -550,20 +584,19 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
auto op_mask_type =
|
||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||
kernel_source << metal::utils() << metal::gemv_masked()
|
||||
<< fmt::format(
|
||||
gemv_masked_kernel,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"outm_t"_a = out_mask_type,
|
||||
"opm_t"_a = op_mask_type,
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"sm"_a = sm,
|
||||
"sn"_a = sn,
|
||||
"tm"_a = tm,
|
||||
"tn"_a = tn,
|
||||
"trans"_a = transpose_mat ? "t_" : "",
|
||||
"nc"_a = contiguous ? "0" : "1");
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
(transpose_mat) ? "gemv_t_masked" : "gemv_masked",
|
||||
get_type_string(out.dtype()),
|
||||
out_mask_type,
|
||||
op_mask_type,
|
||||
bm,
|
||||
bn,
|
||||
sm,
|
||||
sn,
|
||||
tm,
|
||||
tn,
|
||||
contiguous ? 0 : 1);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -584,17 +617,17 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||
<< fmt::format(
|
||||
steel_conv_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"n_channels"_a = n_channel_specialization,
|
||||
"small_filter"_a = small_filter);
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
"implicit_gemm_conv_2d",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
n_channel_specialization,
|
||||
small_filter);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -614,15 +647,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv()
|
||||
<< metal::steel_conv_general()
|
||||
<< fmt::format(
|
||||
steel_conv_general_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn);
|
||||
<< get_template_definition(
|
||||
lib_name,
|
||||
"implicit_gemm_conv_2d_general",
|
||||
get_type_string(out.dtype()),
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn);
|
||||
return kernel_source.str();
|
||||
});
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <format>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@@ -45,6 +45,12 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_dynamic_copy_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_softmax_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -212,7 +218,7 @@ get_template_definition(std::string name, std::string func, Args... args) {
|
||||
};
|
||||
(add_arg(args), ...);
|
||||
s << ">";
|
||||
return fmt::format(
|
||||
return std::format(
|
||||
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
|
||||
name,
|
||||
s.str());
|
||||
|
@@ -9,21 +9,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
|
@@ -7,21 +7,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
|
@@ -161,3 +161,78 @@ template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
idx.y += dst_xstride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_dynamic_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
constant const int64_t& src_offset [[buffer(6)]],
|
||||
constant const int64_t& dst_offset [[buffer(7)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
|
||||
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_dynamic_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int64_t& src_offset [[buffer(6)]],
|
||||
constant const int64_t& dst_offset [[buffer(7)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
|
||||
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_dynamic_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int64_t& src_offset [[buffer(6)]],
|
||||
constant const int64_t& dst_offset [[buffer(7)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
|
||||
dst[dst_idx + dst_offset] = src[src_idx + src_offset];
|
||||
}
|
||||
|
||||
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
|
||||
[[kernel]] void copy_gg_dynamic(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
constant const int64_t& src_offset [[buffer(6)]],
|
||||
constant const int64_t& dst_offset [[buffer(7)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
src += src_offset;
|
||||
dst += dst_offset;
|
||||
auto idx = elem_to_loc_2_nd<IdxT>(
|
||||
{N * index.x, index.y, index.z},
|
||||
src_shape,
|
||||
src_strides,
|
||||
dst_strides,
|
||||
ndim);
|
||||
if (N == 1) {
|
||||
dst[idx.y] = src[idx.x];
|
||||
return;
|
||||
}
|
||||
IdxT src_xstride = src_strides[ndim - 1];
|
||||
IdxT dst_xstride = dst_strides[ndim - 1];
|
||||
auto xshape = src_shape[ndim - 1];
|
||||
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
|
||||
dst[idx.y] = src[idx.x];
|
||||
idx.x += src_xstride;
|
||||
idx.y += dst_xstride;
|
||||
}
|
||||
}
|
||||
|
@@ -4,29 +4,40 @@
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
|
||||
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
|
||||
instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
|
||||
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4)
|
||||
|
||||
#define instantiate_copy_same(tname, type) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \
|
||||
instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \
|
||||
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \
|
||||
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \
|
||||
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \
|
||||
instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) \
|
||||
instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int) \
|
||||
instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int) \
|
||||
instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int) \
|
||||
instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int) \
|
||||
instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type) \
|
||||
instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type) \
|
||||
instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type) \
|
||||
instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_same(itname ##itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
|
||||
|
@@ -1323,13 +1323,14 @@ template <typename T, int group_size, int bits, int D, bool batched>
|
||||
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||
if (batched) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
@@ -1374,13 +1375,14 @@ template <typename T, int group_size, int bits, bool batched>
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
if (batched) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
@@ -1425,13 +1427,14 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
if (batched) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
@@ -1476,13 +1479,14 @@ template <typename T, const int group_size, const int bits, bool batched>
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
if (batched) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
@@ -1527,13 +1531,14 @@ template <typename T, const int group_size, const int bits, int split_k = 32>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
x_batch_ndims,
|
||||
x_shape,
|
||||
x_strides,
|
||||
@@ -1706,6 +1711,7 @@ template <typename T, int group_size, int bits>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
@@ -1714,7 +1720,7 @@ template <typename T, int group_size, int bits>
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
batch_ndims,
|
||||
batch_shape,
|
||||
lhs_strides,
|
||||
@@ -1767,6 +1773,7 @@ template <typename T, int group_size, int bits>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
@@ -1775,7 +1782,7 @@ template <typename T, int group_size, int bits>
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
batch_ndims,
|
||||
batch_shape,
|
||||
lhs_strides,
|
||||
@@ -1828,6 +1835,7 @@ template <typename T, int group_size, int bits>
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
int M = x_shape[x_batch_ndims];
|
||||
adjust_matrix_offsets<T>(
|
||||
x,
|
||||
w,
|
||||
@@ -1836,7 +1844,7 @@ template <typename T, int group_size, int bits>
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
y,
|
||||
out_vec_size,
|
||||
out_vec_size * M,
|
||||
batch_ndims,
|
||||
batch_shape,
|
||||
lhs_strides,
|
||||
|
@@ -53,10 +53,10 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, int64_t, dim) \
|
||||
@@ -67,7 +67,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
itype, otype, op, int, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, int64_t, dim, bm, bn)
|
||||
@@ -75,7 +75,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
itype, otype, op, int, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, int64_t, dim, bm, bn)
|
||||
@@ -95,7 +95,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, int64_t, dim)
|
||||
@@ -103,7 +103,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, int64_t, dim)
|
||||
|
@@ -4,6 +4,8 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
constant bool has_mask [[function_constant(20)]];
|
||||
|
||||
template <typename T, int D>
|
||||
[[kernel]] void sdpa_vector(
|
||||
const device T* queries [[buffer(0)]],
|
||||
@@ -15,6 +17,9 @@ template <typename T, int D>
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -39,6 +44,9 @@ template <typename T, int D>
|
||||
queries += head_idx * D + simd_lid * elem_per_thread;
|
||||
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
|
||||
}
|
||||
out += head_idx * D + simd_gid * elem_per_thread;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
@@ -54,34 +62,39 @@ template <typename T, int D>
|
||||
|
||||
// For each key
|
||||
for (int i = simd_gid; i < N; i += BN) {
|
||||
// Read the key
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
}
|
||||
if (!has_mask || mask[0]) {
|
||||
// Read the key
|
||||
for (int j = 0; j < elem_per_thread; j++) {
|
||||
k[j] = keys[j];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int j = 0; j < elem_per_thread; j++) {
|
||||
score += q[j] * k[j];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
// Update the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = o[i] * factor + exp_score * values[i];
|
||||
// Update the output accumulator
|
||||
for (int j = 0; j < elem_per_thread; j++) {
|
||||
o[j] = o[j] * factor + exp_score * values[j];
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
keys += stride;
|
||||
values += stride;
|
||||
if (has_mask) {
|
||||
mask += BN * mask_seq_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
@@ -126,6 +139,9 @@ template <typename T, int D>
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
@@ -155,6 +171,10 @@ template <typename T, int D>
|
||||
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * elem_per_thread;
|
||||
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride +
|
||||
(block_idx * BN + simd_gid) * mask_seq_stride;
|
||||
}
|
||||
sums += head_idx * blocks + block_idx;
|
||||
maxs += head_idx * blocks + block_idx;
|
||||
|
||||
@@ -171,34 +191,39 @@ template <typename T, int D>
|
||||
|
||||
// For each key
|
||||
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
|
||||
// Read the key
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
}
|
||||
if (!has_mask || mask[0]) {
|
||||
// Read the key
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
k[i] = keys[i];
|
||||
}
|
||||
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
// Compute the i-th score
|
||||
U score = 0;
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
score += q[i] * k[i];
|
||||
}
|
||||
score = simd_sum(score);
|
||||
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
// Update the accumulators
|
||||
U new_max = max(max_score, score);
|
||||
U factor = fast::exp(max_score - new_max);
|
||||
U exp_score = fast::exp(score - new_max);
|
||||
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
max_score = new_max;
|
||||
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||
|
||||
// Update the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = o[i] * factor + exp_score * values[i];
|
||||
// Update the output accumulator
|
||||
for (int i = 0; i < elem_per_thread; i++) {
|
||||
o[i] = o[i] * factor + exp_score * values[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointers to the next kv
|
||||
keys += blocks * stride;
|
||||
values += blocks * stride;
|
||||
if (has_mask) {
|
||||
mask += BN * blocks * mask_seq_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Each thread has a partial part of the output so we need to combine them.
|
||||
|
@@ -8,17 +8,17 @@
|
||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
|
||||
#define instantiate_ternary_types(op) \
|
||||
instantiate_ternary_all(op, bool_, bool) \
|
||||
|
@@ -9,19 +9,19 @@
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
|
||||
instantiate_kernel( \
|
||||
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
||||
#define instantiate_unary_float(op) \
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all_same(op, float16, half) \
|
||||
instantiate_unary_all_same(op, float32, float) \
|
||||
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_unary_types(op) \
|
||||
#define instantiate_unary_types(op) \
|
||||
instantiate_unary_all_same(op, bool_, bool) \
|
||||
instantiate_unary_all_same(op, uint8, uint8_t) \
|
||||
instantiate_unary_all_same(op, uint16, uint16_t) \
|
||||
|
@@ -91,21 +91,7 @@ struct Limits<complex64_t> {
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const int64_t* strides,
|
||||
int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
int64_t elem,
|
||||
IdxT elem,
|
||||
constant const int* shape,
|
||||
constant const int64_t* strides,
|
||||
int ndim) {
|
||||
@@ -187,9 +173,12 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
|
||||
constant const int64_t* c_strides,
|
||||
int ndim) {
|
||||
vec<IdxT, 3> loc = {
|
||||
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
|
||||
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
|
||||
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
|
||||
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
|
||||
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
|
||||
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
|
||||
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
|
||||
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
|
||||
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * IdxT(a_strides[d]);
|
||||
|
@@ -47,7 +47,11 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
try {
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
} catch (const std::exception& error) {
|
||||
abort_with_exception(error);
|
||||
}
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
|
@@ -56,6 +56,14 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_dynamic_copy_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array&,
|
||||
const array&) {
|
||||
return d.get_kernel(kernel_name);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_softmax_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
@@ -4,11 +4,13 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@@ -25,6 +27,78 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||
enc.set_bytes(step, 1);
|
||||
}
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
array compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& axes,
|
||||
Stream s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Kernel to compute offset here.
|
||||
array offset({1}, int64, nullptr, {});
|
||||
bool donate = indices.is_donatable() &&
|
||||
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
||||
if (donate) {
|
||||
offset.move_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
}
|
||||
d.add_temporary(offset, s.index);
|
||||
|
||||
auto dtype = indices.dtype();
|
||||
std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype);
|
||||
auto lib = d.get_library(lib_name, [dtype]() {
|
||||
return std::format(
|
||||
R"(
|
||||
[[kernel]] void compute_dynamic_offset_{0}(
|
||||
constant const {1}* indices [[buffer(0)]],
|
||||
device int64_t& offset [[buffer(1)]],
|
||||
constant const int64_t* strides [[buffer(2)]],
|
||||
constant const int* axes [[buffer(3)]],
|
||||
constant const int& n_axes [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {{
|
||||
int64_t acc = 0;
|
||||
for (int i = 0; i < n_axes; ++i) {{
|
||||
acc += indices[i] * strides[axes[i]];
|
||||
}}
|
||||
offset = acc;
|
||||
}})",
|
||||
type_to_name(dtype),
|
||||
get_type_string(dtype));
|
||||
});
|
||||
auto kernel = d.get_kernel(lib_name, lib);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donate ? offset : indices, 0);
|
||||
compute_encoder.set_output_array(offset, 1);
|
||||
compute_encoder.set_vector_bytes(strides, 2);
|
||||
compute_encoder.set_vector_bytes(axes, 3);
|
||||
int n_axes = axes.size();
|
||||
compute_encoder.set_bytes(n_axes, 4);
|
||||
MTL::Size dims = MTL::Size(1, 1, 1);
|
||||
compute_encoder.dispatch_threads(dims, dims);
|
||||
return offset;
|
||||
}
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
@@ -167,6 +241,10 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
@@ -211,6 +289,18 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
copy_gpu(in, out, ctype);
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out = out,
|
||||
@@ -226,18 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
read_task();
|
||||
return;
|
||||
}
|
||||
|
||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||
auto signal_task = [out = out, fut = std::move(fut)]() {
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
encode_wait(e);
|
||||
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
|
||||
fut.wait();
|
||||
out.event().signal();
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(io_stream(), std::move(signal_task));
|
||||
auto& d = metal::device(stream().device);
|
||||
d.end_encoding(stream().index);
|
||||
auto command_buffer = d.get_command_buffer(stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -305,26 +394,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
reshape(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
@@ -344,6 +414,72 @@ void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
slice_gpu(in, out, start_indices_, strides_, stream());
|
||||
}
|
||||
|
||||
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& start = inputs[1];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto s = stream();
|
||||
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
/* const Strides& i_strides = */ in.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::optional<array>& dynamic_i_offset = */ in_offset,
|
||||
/* const std::optional<array>& dynamic_o_offset = */ std::nullopt);
|
||||
}
|
||||
|
||||
void DynamicSliceUpdate::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
auto& start_indices = inputs[2];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
move_or_copy(in, out);
|
||||
return;
|
||||
}
|
||||
|
||||
// Copy or donate input to output
|
||||
auto s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
|
||||
|
||||
auto out_offset =
|
||||
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ upd.shape(),
|
||||
/* const Strides& i_strides = */ upd.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ s,
|
||||
/* const std::optional<array>& dynamic_i_offset = */ std::nullopt,
|
||||
/* const std::optional<array>& dynamic_o_offset = */ out_offset);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
@@ -359,13 +495,11 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if materialization is needed
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
@@ -381,6 +515,10 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const Stream& s = */ stream());
|
||||
}
|
||||
|
||||
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
@@ -424,7 +562,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
|
||||
|
@@ -393,7 +393,7 @@ void row_reduce_small(
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
const std::string func_name = "row_reduce_small";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -411,7 +411,7 @@ void row_reduce_small(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
large ? "size_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -490,7 +490,7 @@ void row_reduce_looped(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
const std::string func_name = "row_reduce_looped";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -508,7 +508,7 @@ void row_reduce_looped(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
large ? "size_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -574,7 +574,7 @@ void strided_reduce_small(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
const std::string func_name = "col_reduce_small";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -592,7 +592,7 @@ void strided_reduce_small(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
large ? "size_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
|
||||
}
|
||||
|
||||
// Prepare the temporary accumulator
|
||||
std::vector<int> intermediate_shape;
|
||||
Shape intermediate_shape;
|
||||
intermediate_shape.reserve(out.ndim() + 1);
|
||||
intermediate_shape.push_back(outer_blocks);
|
||||
intermediate_shape.insert(
|
||||
@@ -665,7 +665,7 @@ void strided_reduce_longcolumn(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
std::string func_name = "col_reduce_longcolumn";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -683,7 +683,7 @@ void strided_reduce_longcolumn(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@@ -706,7 +706,7 @@ void strided_reduce_longcolumn(
|
||||
// Set the 2nd kernel
|
||||
func_name = "col_reduce_looped";
|
||||
kname = func_name;
|
||||
large = intermediate.size() > UINT32_MAX;
|
||||
large = intermediate.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -718,7 +718,7 @@ void strided_reduce_longcolumn(
|
||||
op_name,
|
||||
intermediate.dtype(),
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
1,
|
||||
32,
|
||||
32);
|
||||
@@ -760,7 +760,7 @@ void strided_reduce_looped(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
std::string func_name = "col_reduce_looped";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -782,7 +782,7 @@ void strided_reduce_looped(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
n,
|
||||
BM,
|
||||
BN);
|
||||
@@ -806,7 +806,7 @@ void strided_reduce_2pass(
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
|
||||
// Prepare the temporary accumulator
|
||||
std::vector<int> intermediate_shape;
|
||||
Shape intermediate_shape;
|
||||
intermediate_shape.reserve(out.ndim() + 1);
|
||||
intermediate_shape.push_back(32);
|
||||
intermediate_shape.insert(
|
||||
@@ -837,7 +837,7 @@ void strided_reduce_2pass(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
std::string func_name = "col_reduce_2pass";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -859,7 +859,7 @@ void strided_reduce_2pass(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
n,
|
||||
BM,
|
||||
BN);
|
||||
@@ -882,7 +882,7 @@ void strided_reduce_2pass(
|
||||
// Set the 2nd kernel
|
||||
func_name = "col_reduce_looped";
|
||||
kname = func_name;
|
||||
large = intermediate.size() > UINT32_MAX;
|
||||
large = intermediate.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@@ -894,7 +894,7 @@ void strided_reduce_2pass(
|
||||
op_name,
|
||||
intermediate.dtype(),
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
1,
|
||||
32,
|
||||
32);
|
||||
|
@@ -66,7 +66,7 @@ void RoPE::eval_gpu(
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
|
||||
bool with_freqs = inputs.size() == 2;
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (single ? "single_" : "")
|
||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||
@@ -78,7 +78,7 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(offset_, 2);
|
||||
compute_encoder.set_input_array(inputs[1], 2);
|
||||
compute_encoder.set_bytes(scale_, 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
@@ -101,7 +101,7 @@ void RoPE::eval_gpu(
|
||||
}
|
||||
|
||||
if (with_freqs) {
|
||||
auto& freqs = inputs[1];
|
||||
auto& freqs = inputs[2];
|
||||
compute_encoder.set_input_array(freqs, 10);
|
||||
auto freq_stride = freqs.strides()[0];
|
||||
compute_encoder.set_bytes(freq_stride, 11);
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -116,7 +115,8 @@ void sdpa_vector(
|
||||
const array& k,
|
||||
const array& v,
|
||||
array& out,
|
||||
float scale) {
|
||||
float scale,
|
||||
const std::optional<array>& mask) {
|
||||
// Set the kernel name
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
@@ -134,9 +134,16 @@ void sdpa_vector(
|
||||
MTL::Size group_dims(1024, 1, 1);
|
||||
MTL::Size grid_dims(1, B, 1);
|
||||
|
||||
bool has_mask = mask.has_value();
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||
};
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
@@ -149,6 +156,14 @@ void sdpa_vector(
|
||||
compute_encoder.set_bytes(k_stride, 6);
|
||||
compute_encoder.set_bytes(v_stride, 7);
|
||||
compute_encoder.set_bytes(scale, 8);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 9);
|
||||
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
|
||||
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
|
||||
compute_encoder.set_bytes(seq_stride, 10);
|
||||
compute_encoder.set_bytes(head_stride, 11);
|
||||
}
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
@@ -161,7 +176,8 @@ void sdpa_vector_2pass(
|
||||
const array& k,
|
||||
const array& v,
|
||||
array& out,
|
||||
float scale) {
|
||||
float scale,
|
||||
const std::optional<array>& mask) {
|
||||
// Set the kernel name
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
@@ -198,9 +214,17 @@ void sdpa_vector_2pass(
|
||||
d.add_temporary(sums, s.index);
|
||||
d.add_temporary(maxs, s.index);
|
||||
|
||||
bool has_mask = mask.has_value();
|
||||
metal::MTLFCList func_consts = {
|
||||
{&has_mask, MTL::DataType::DataTypeBool, 20},
|
||||
};
|
||||
std::string hash_name = kname;
|
||||
hash_name += has_mask ? "_mask" : "_nomask";
|
||||
|
||||
// Get the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname);
|
||||
auto kernel = d.get_kernel(kname, "mlx", hash_name, func_consts);
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
@@ -215,6 +239,14 @@ void sdpa_vector_2pass(
|
||||
compute_encoder.set_bytes(k_stride, 8);
|
||||
compute_encoder.set_bytes(v_stride, 9);
|
||||
compute_encoder.set_bytes(scale, 10);
|
||||
if (has_mask) {
|
||||
auto& m = *mask;
|
||||
compute_encoder.set_input_array(m, 11);
|
||||
int32_t seq_stride = m.ndim() >= 1 ? m.strides().back() : 0;
|
||||
int32_t head_stride = m.ndim() >= 3 ? *(m.strides().end() - 3) : 0;
|
||||
compute_encoder.set_bytes(seq_stride, 12);
|
||||
compute_encoder.set_bytes(head_stride, 13);
|
||||
}
|
||||
|
||||
// Launch
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
@@ -247,8 +279,6 @@ void sdpa_vector_2pass(
|
||||
void ScaledDotProductAttention::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
@@ -296,6 +326,8 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) == 1) {
|
||||
const auto& q = copy_unless(is_contiguous, q_pre);
|
||||
// 1, heads, seq_len, head_dim
|
||||
// mask [1, query_heads, 1, seq_len]
|
||||
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||
|
||||
@@ -306,15 +338,18 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
}
|
||||
|
||||
auto mask =
|
||||
inputs.size() > 3 ? std::optional<array>{inputs[3]} : std::nullopt;
|
||||
|
||||
// We route to the 2 pass fused attention if
|
||||
// - The device is large and the sequence length long
|
||||
// - The sequence length is even longer and we have gqa
|
||||
char devc = d.get_architecture().back();
|
||||
if ((devc == 'd' && k.shape(2) >= 1024) ||
|
||||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
|
||||
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
|
||||
sdpa_vector_2pass(s, d, q, k, v, o, scale_, mask);
|
||||
} else {
|
||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||
sdpa_vector(s, d, q, k, v, o, scale_, mask);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -14,35 +14,18 @@ void slice_gpu(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
// Calculate out strides and initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
auto copy_needed =
|
||||
std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; });
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
/* const array& in = */ in,
|
||||
/* array& out = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General,
|
||||
/* const Stream& s = */ s);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < strides.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < strides.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void concatenate_gpu(
|
||||
@@ -80,8 +63,8 @@ void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
std::vector<int> axes,
|
||||
std::vector<int> low_pad_size,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s) {
|
||||
// Fill output with val
|
||||
fill_gpu(val, out, s);
|
||||
|
@@ -23,8 +23,8 @@ void pad_gpu(
|
||||
const array& in,
|
||||
const array& val,
|
||||
array& out,
|
||||
std::vector<int> axes,
|
||||
std::vector<int> low_pad_size,
|
||||
const std::vector<int>& axes,
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -82,9 +82,17 @@ void single_block_sort(
|
||||
compute_encoder.set_bytes(out_stride_segment_axis, 6);
|
||||
} else {
|
||||
compute_encoder.set_bytes(nc_dim, 5);
|
||||
compute_encoder.set_vector_bytes(nc_shape, 6);
|
||||
compute_encoder.set_vector_bytes(in_nc_str, 7);
|
||||
compute_encoder.set_vector_bytes(out_nc_str, 8);
|
||||
if (nc_shape.empty()) {
|
||||
int shape = 0;
|
||||
int64_t stride = 0;
|
||||
compute_encoder.set_bytes(shape, 6);
|
||||
compute_encoder.set_bytes(stride, 7);
|
||||
compute_encoder.set_bytes(stride, 8);
|
||||
} else {
|
||||
compute_encoder.set_vector_bytes(nc_shape, 6);
|
||||
compute_encoder.set_vector_bytes(in_nc_str, 7);
|
||||
compute_encoder.set_vector_bytes(out_nc_str, 8);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
|
@@ -36,15 +36,15 @@ void ternary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
|
||||
|
||||
bool large = out.data_size() > UINT_MAX;
|
||||
bool large;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (topt == TernaryOpType::General) {
|
||||
large |=
|
||||
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX);
|
||||
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > INT32_MAX;
|
||||
work_per_thread = 1;
|
||||
}
|
||||
std::string kernel_name;
|
||||
|
@@ -36,9 +36,11 @@ void unary_op_gpu_inplace(
|
||||
auto [shape, strides] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
bool large = in.data_size() > UINT32_MAX;
|
||||
bool large;
|
||||
if (!contig) {
|
||||
large |= in.size() > UINT32_MAX;
|
||||
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
} else {
|
||||
large = in.data_size() > UINT32_MAX;
|
||||
}
|
||||
int work_per_thread = !contig && large ? 4 : 1;
|
||||
std::string kernel_name;
|
||||
|
@@ -35,6 +35,7 @@ NO_CPU(AsStrided)
|
||||
NO_CPU(BitwiseBinary)
|
||||
NO_CPU(BlockMaskedMM)
|
||||
NO_CPU(Broadcast)
|
||||
NO_CPU(BroadcastAxes)
|
||||
NO_CPU(Ceil)
|
||||
NO_CPU(Cholesky)
|
||||
NO_CPU(Concatenate)
|
||||
@@ -48,6 +49,8 @@ NO_CPU_MULTI(CustomTransforms)
|
||||
NO_CPU_MULTI(Depends)
|
||||
NO_CPU(Divide)
|
||||
NO_CPU_MULTI(DivMod)
|
||||
NO_CPU(DynamicSlice)
|
||||
NO_CPU(DynamicSliceUpdate)
|
||||
NO_CPU(NumberOfElements)
|
||||
NO_CPU(Remainder)
|
||||
NO_CPU_MULTI(Eigh)
|
||||
@@ -55,8 +58,10 @@ NO_CPU(Equal)
|
||||
NO_CPU(Erf)
|
||||
NO_CPU(ErfInv)
|
||||
NO_CPU(Exp)
|
||||
NO_CPU(ExpandDims)
|
||||
NO_CPU(Expm1)
|
||||
NO_CPU(FFT)
|
||||
NO_CPU(Flatten)
|
||||
NO_CPU(Floor)
|
||||
NO_CPU(Full)
|
||||
NO_CPU(Gather)
|
||||
@@ -104,6 +109,7 @@ NO_CPU(Softmax)
|
||||
NO_CPU(Sort)
|
||||
NO_CPU_MULTI(Split)
|
||||
NO_CPU(Square)
|
||||
NO_CPU(Squeeze)
|
||||
NO_CPU(Sqrt)
|
||||
NO_CPU(StopGradient)
|
||||
NO_CPU(Subtract)
|
||||
@@ -111,6 +117,7 @@ NO_CPU_MULTI(SVD)
|
||||
NO_CPU(Tan)
|
||||
NO_CPU(Tanh)
|
||||
NO_CPU(Transpose)
|
||||
NO_CPU(Unflatten)
|
||||
NO_CPU(Inverse)
|
||||
NO_CPU(View)
|
||||
|
||||
|
@@ -36,6 +36,7 @@ NO_GPU(AsStrided)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(BroadcastAxes)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Concatenate)
|
||||
@@ -49,14 +50,18 @@ NO_GPU_MULTI(CustomTransforms)
|
||||
NO_GPU_MULTI(Depends)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(NumberOfElements)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
NO_GPU(ErfInv)
|
||||
NO_GPU(Exp)
|
||||
NO_GPU(ExpandDims)
|
||||
NO_GPU(Expm1)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(Flatten)
|
||||
NO_GPU(Floor)
|
||||
NO_GPU(Full)
|
||||
NO_GPU(Gather)
|
||||
@@ -104,6 +109,7 @@ NO_GPU(Softmax)
|
||||
NO_GPU(Sort)
|
||||
NO_GPU_MULTI(Split)
|
||||
NO_GPU(Square)
|
||||
NO_GPU(Squeeze)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(StopGradient)
|
||||
NO_GPU(Subtract)
|
||||
@@ -111,6 +117,7 @@ NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Transpose)
|
||||
NO_GPU(Unflatten)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
@@ -68,24 +68,7 @@ bool is_reduction(const Primitive& p) {
|
||||
}
|
||||
|
||||
bool is_fusable(const Primitive& p) {
|
||||
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
|
||||
is_noop(p);
|
||||
}
|
||||
|
||||
bool allows_shapeless(const Primitive& p) {
|
||||
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
|
||||
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
|
||||
typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) ||
|
||||
typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) ||
|
||||
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
|
||||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||
typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) ||
|
||||
typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||
typeid(p) == typeid(fast::LayerNorm) ||
|
||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||
typeid(p) == typeid(fast::ScaledDotProductAttention);
|
||||
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
|
||||
}
|
||||
|
||||
Compiled::Compiled(
|
||||
@@ -172,9 +155,6 @@ CompileMode& compile_mode() {
|
||||
return compile_mode_;
|
||||
}
|
||||
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Helper like below but only merges the two provided arrays. If the src has
|
||||
// siblings then these won't be merged to the dst.
|
||||
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
||||
@@ -303,9 +283,10 @@ CompilerCache& compiler_cache() {
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs) {
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
detail::InTracing in_tracing{shapeless};
|
||||
|
||||
// Run the function on placeholder inputs
|
||||
// to get compute graph
|
||||
@@ -369,12 +350,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
return {tape, parents_map};
|
||||
}
|
||||
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||
// the parents map to remove orphaned arrays
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape,
|
||||
// the parents map to remove orphaned arrays, and potentially the outputs
|
||||
void compile_simplify(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& outputs,
|
||||
std::vector<array>& outputs,
|
||||
int passes) {
|
||||
// Helpers to identify identical scalars
|
||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||
@@ -451,6 +432,28 @@ void compile_simplify(
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
|
||||
// Remove no-ops
|
||||
{
|
||||
std::unordered_map<uintptr_t, array> output_map;
|
||||
for (auto& o : outputs) {
|
||||
output_map.insert({o.id(), o});
|
||||
}
|
||||
for (auto& arr : tape) {
|
||||
if (!arr.has_primitive() || !is_noop(arr.primitive())) {
|
||||
new_tape.push_back(std::move(arr));
|
||||
continue;
|
||||
}
|
||||
merge_one(arr.inputs()[0], arr, parents_map);
|
||||
if (auto it = output_map.find(arr.id()); it != output_map.end()) {
|
||||
it->second = arr.inputs()[0];
|
||||
}
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
for (auto& o : outputs) {
|
||||
o = output_map.at(o.id());
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
|
||||
for (uint32_t i = 0; i < tape.size(); ++i) {
|
||||
tape_order.insert({tape[i].id(), i});
|
||||
@@ -460,6 +463,7 @@ void compile_simplify(
|
||||
for (auto& o : outputs) {
|
||||
output_set.insert(o.id());
|
||||
}
|
||||
|
||||
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
||||
for (int pass = 0; pass < passes; ++pass) {
|
||||
for (auto& arr : tape) {
|
||||
@@ -748,10 +752,15 @@ std::vector<array> compile_replace(
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
}
|
||||
|
||||
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
|
||||
|
||||
for (auto& a : tape) {
|
||||
// Arrays in the tape without primitives are constants
|
||||
// and can be used directly
|
||||
if (!a.has_primitive()) {
|
||||
// Arrays in the tape without primitives are either:
|
||||
// - inputs, which are already in the map
|
||||
// - constants, which can be used directly
|
||||
// - a load primitive which has no inputs and will become a constant
|
||||
// after the first eval
|
||||
if (!a.has_primitive() || is_load(a.primitive())) {
|
||||
trace_to_real.insert({a.id(), a});
|
||||
} else {
|
||||
// Find real inputs
|
||||
@@ -799,24 +808,6 @@ std::vector<array> compile_replace(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
for (auto& t : tape) {
|
||||
if (!t.has_primitive()) {
|
||||
continue;
|
||||
}
|
||||
auto& p = t.primitive();
|
||||
if (allows_shapeless(p)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Cannot compile primitive ";
|
||||
p.print(msg);
|
||||
msg << " with shapeless enabled.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
bool skip_compile() {
|
||||
return compile_mode() == CompileMode::disabled ||
|
||||
!(compile_available_for_device(default_device()));
|
||||
@@ -856,7 +847,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
// Set the constants
|
||||
entry.constants = std::move(constants);
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||
std::tie(entry.inputs, entry.outputs) =
|
||||
compile_trace(fun, inputs, shapeless);
|
||||
|
||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||
// position in parent inputs)
|
||||
@@ -876,10 +868,6 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (compile_mode() != CompileMode::no_fuse) {
|
||||
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||
}
|
||||
|
||||
if (shapeless) {
|
||||
compile_validate_shapeless(entry.tape);
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
|
@@ -2,7 +2,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/device.h"
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
@@ -22,4 +24,35 @@ void compile_erase(std::uintptr_t fun_id);
|
||||
void compile_clear_cache();
|
||||
|
||||
bool compile_available_for_device(const Device& device);
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless);
|
||||
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& original_inputs);
|
||||
|
||||
// Simplify the tape.
|
||||
void compile_simplify(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
std::vector<array>& outputs,
|
||||
int passes);
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs,
|
||||
bool shapeless);
|
||||
|
||||
void compile_validate_shapeless(const std::vector<array>& tape);
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user