Compare commits
..

1 commit

Author: Awni Hannun
SHA1: fc6494cac7
Message: register in task submission
Date: 2024-11-05 14:42:51 -08:00

438 changed files with 19519 additions and 34927 deletions

In the reconstructed diffs below, "-" lines come from the base side of the compare and "+" lines from commit fc6494cac7.

File: (CI configuration; path not shown)

@@ -85,19 +85,17 @@ jobs:
           name: Install dependencies
           command: |
             pip install --upgrade cmake
-            pip install nanobind==2.4.0
+            pip install nanobind==2.2.0
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       - run:
           name: Install Python package
           command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
-              CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
-              CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
             CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
             python3 setup.py develop
       - run:
@@ -139,7 +137,7 @@ jobs:
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.4.0
+            pip install nanobind==2.2.0
             pip install numpy
             pip install torch
             pip install tensorflow
@@ -148,9 +146,7 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-              CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
-              pip install -e . -v
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
@@ -164,7 +160,6 @@ jobs:
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
             mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
-            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
       - run:
           name: Build example extension
           command: |
@@ -231,7 +226,7 @@ jobs:
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.4.0
+            pip install nanobind==2.2.0
             pip install --upgrade setuptools
             pip install numpy
             pip install twine
@@ -296,7 +291,7 @@ jobs:
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.4.0
+            pip install nanobind==2.2.0
             pip install --upgrade setuptools
             pip install numpy
             pip install auditwheel

File: .gitignore (vendored), 3 lines changed

@@ -76,9 +76,6 @@ build/
 *.out
 *.app
 
-# Debug symbols
-*.pdb
-
 # VSCode
 .vscode/
 .DS_Store

File: (pre-commit configuration; path not shown)

@@ -1,16 +1,15 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v19.1.7
+    rev: v18.1.8
    hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 25.1.0
+    rev: 24.8.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 6.0.0
+    rev: 5.13.2
     hooks:
       - id: isort
         args:

File: (contributor acknowledgments; path not shown)

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.

File: (top-level CMake build file; path not shown)

@@ -1,24 +1,6 @@
-cmake_minimum_required(VERSION 3.25)
+cmake_minimum_required(VERSION 3.24)
 
-if(NOT MLX_VERSION)
-  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
-  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
-  set(_major ${CMAKE_MATCH_1})
-  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
-  set(_minor ${CMAKE_MATCH_1})
-  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
-  set(_patch ${CMAKE_MATCH_1})
-  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
-  set(MLX_VERSION ${MLX_PROJECT_VERSION})
-else()
-  string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
-                       ${MLX_VERSION})
-endif()
-
-project(
-  mlx
-  LANGUAGES C CXX
-  VERSION ${MLX_PROJECT_VERSION})
+project(mlx LANGUAGES C CXX)
 
 # ----------------------------- Setup -----------------------------
 set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
@@ -38,16 +20,22 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
-option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
+if(NOT MLX_VERSION)
+  set(MLX_VERSION 0.19.3)
+endif()
+
 # --------------------- Processor tests -------------------------
 message(
   STATUS
     "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
 )
 
+set(MLX_BUILD_ARM OFF)
+
 if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
     if(NOT MLX_ENABLE_X64_MAC)
@@ -69,6 +57,10 @@ else()
   message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
 
+if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
+  set(MLX_BUILD_ARM ON)
+endif()
+
 # ----------------------------- Lib -----------------------------
 include(FetchContent)
@@ -97,26 +89,25 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
-  if(${MACOS_SDK_VERSION} LESS 14.0)
+    OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
+  if(${MACOS_VERSION} LESS 14.0)
     message(
       FATAL_ERROR
         "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
   endif()
-  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
+  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
 
   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
+  )
 
-  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
-    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
-  endif()
+  # Get the metal version
   execute_process(
     COMMAND
       zsh "-c"
-      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
+      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
   FetchContent_MakeAvailable(metal_cpp)
@@ -124,58 +115,20 @@ elseif(MLX_BUILD_METAL)
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
                $<INSTALL_INTERFACE:include/metal_cpp>)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
-endif()
 
-if(WIN32)
-  if(MSVC)
-    # GGUF does not build with MSVC.
-    set(MLX_BUILD_GGUF OFF)
-    # There is no prebuilt OpenBLAS distribution for MSVC.
-    set(MLX_BUILD_BLAS_FROM_SOURCE ON)
-  endif()
-  # Windows implementation of dlfcn.h APIs.
-  FetchContent_Declare(
-    dlfcn-win32
-    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
-    GIT_TAG v1.4.1
-    EXCLUDE_FROM_ALL)
-  block()
-  set(BUILD_SHARED_LIBS OFF)
-  FetchContent_MakeAvailable(dlfcn-win32)
-  endblock()
-  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
-  target_link_libraries(mlx PRIVATE dl)
+  add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
 endif()
 
 if(MLX_BUILD_CPU)
   find_library(ACCELERATE_LIBRARY Accelerate)
-  if(ACCELERATE_LIBRARY)
+  if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
+    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
+    add_compile_definitions(ACCELERATE_NEW_LAPACK)
   else()
     message(STATUS "Accelerate or arm neon not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
-  endif()
 
-  if(MLX_BUILD_ACCELERATE)
-    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
-    add_compile_definitions(MLX_USE_ACCELERATE)
-    add_compile_definitions(ACCELERATE_NEW_LAPACK)
-  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
-    # Download and build OpenBLAS from source code.
-    FetchContent_Declare(
-      openblas
-      GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
-      GIT_TAG v0.3.28
-      EXCLUDE_FROM_ALL)
-    set(BUILD_STATIC_LIBS ON) # link statically
-    set(NOFORTRAN ON) # msvc has no fortran compiler
-    FetchContent_MakeAvailable(openblas)
-    target_link_libraries(mlx PRIVATE openblas)
-    target_include_directories(
-      mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
-                  "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
-  else()
     if(${CMAKE_HOST_APPLE})
       # The blas shipped in macOS SDK is not supported, search homebrew for
       # openblas instead.
@@ -193,7 +146,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
     message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
+    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
     # List blas after lapack otherwise we may accidentally incldue an old
     # version of lapack.h from the include dirs of blas.
     find_package(BLAS REQUIRED)
@@ -206,7 +159,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Blas lib " ${BLAS_LIBRARIES})
     message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
+    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
   endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
@@ -230,14 +183,6 @@ if(MPI_FOUND)
   endif()
 endif()
 
-message(STATUS "Downloading json")
-FetchContent_Declare(
-  json
-  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
-FetchContent_MakeAvailable(json)
-target_include_directories(
-  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
-
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
 target_include_directories(
@@ -261,7 +206,8 @@ if(MLX_BUILD_PYTHON_BINDINGS)
   execute_process(
     COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
     OUTPUT_STRIP_TRAILING_WHITESPACE
-    OUTPUT_VARIABLE nanobind_ROOT)
+    OUTPUT_VARIABLE NB_DIR)
+  list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
   find_package(nanobind CONFIG REQUIRED)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
 endif()
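The final hunk swaps how nanobind's CMake package is located: the newer side stores the output of `python -m nanobind --cmake_dir` in `nanobind_ROOT`, which `find_package(nanobind CONFIG REQUIRED)` consults directly, while the older side appended `NB_DIR` to `CMAKE_PREFIX_PATH`. The query both variants rely on can be reproduced from Python; a minimal sketch, assuming the nanobind package is installed:

    # Minimal sketch: print the directory CMake needs to find nanobind's
    # package config, the same value `python -m nanobind --cmake_dir` emits.
    import nanobind

    print(nanobind.cmake_dir())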

File: (C++ benchmark, value_and_grad; path not shown)

@@ -5,35 +5,35 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-namespace mx = mlx::core;
+using namespace mlx::core;
 
 void time_value_and_grad() {
-  auto x = mx::ones({200, 1000});
-  mx::eval(x);
-  auto fn = [](mx::array x) {
+  auto x = ones({200, 1000});
+  eval(x);
+  auto fn = [](array x) {
     for (int i = 0; i < 20; ++i) {
-      x = mx::log(mx::exp(x));
+      x = log(exp(x));
     }
-    return mx::sum(x);
+    return sum(x);
   };
 
-  auto grad_fn = mx::grad(fn);
+  auto grad_fn = grad(fn);
   auto independent_value_and_grad = [&]() {
     auto value = fn(x);
     auto dfdx = grad_fn(x);
-    return std::vector<mx::array>{value, dfdx};
+    return std::vector<array>{value, dfdx};
   };
   TIME(independent_value_and_grad);
 
-  auto value_and_grad_fn = mx::value_and_grad(fn);
+  auto value_and_grad_fn = value_and_grad(fn);
   auto combined_value_and_grad = [&]() {
     auto [value, dfdx] = value_and_grad_fn(x);
-    return std::vector<mx::array>{value, dfdx};
+    return std::vector<array>{value, dfdx};
   };
   TIME(combined_value_and_grad);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
+  std::cout << "Benchmarks for " << default_device() << std::endl;
   time_value_and_grad();
 }
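This benchmark contrasts evaluating the function and its gradient independently against the fused `value_and_grad` transform, which shares one traced forward pass. The same comparison in MLX's Python API, as a minimal sketch:

    import mlx.core as mx

    def fn(x):
        for _ in range(20):
            x = mx.log(mx.exp(x))
        return x.sum()

    x = mx.ones((200, 1000))
    mx.eval(x)

    # Independent: the forward graph is built twice.
    value, dfdx = fn(x), mx.grad(fn)(x)

    # Combined: one transform yields both outputs.
    value2, dfdx2 = mx.value_and_grad(fn)(x)
    mx.eval(value, dfdx, value2, dfdx2)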

File: (C++ benchmark, CPU vs GPU add; path not shown)

@@ -4,21 +4,21 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-namespace mx = mlx::core;
+using namespace mlx::core;
 
 void time_add_op() {
   std::vector<int> sizes(1, 1);
   for (int i = 0; i < 9; ++i) {
     sizes.push_back(10 * sizes.back());
   }
 
-  set_default_device(mx::Device::cpu);
+  set_default_device(Device::cpu);
   for (auto size : sizes) {
-    auto a = mx::random::uniform({size});
-    auto b = mx::random::uniform({size});
-    mx::eval(a, b);
+    auto a = random::uniform({size});
+    auto b = random::uniform({size});
+    eval(a, b);
     std::cout << "Size " << size << std::endl;
-    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
-    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
+    TIMEM("cpu", add, a, b, Device::cpu);
+    TIMEM("gpu", add, a, b, Device::gpu);
   }
 }
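TIMEM times the identical `add` on both devices by passing the device explicitly rather than switching the default. The Python API exposes the same knob through the `stream` argument that every op accepts; a minimal sketch:

    import mlx.core as mx

    a = mx.random.uniform(shape=(1_000_000,))
    b = mx.random.uniform(shape=(1_000_000,))
    mx.eval(a, b)

    # Schedule the same op on either device, independent of the default.
    c_cpu = mx.add(a, b, stream=mx.cpu)
    c_gpu = mx.add(a, b, stream=mx.gpu)
    mx.eval(c_cpu, c_gpu)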

File: (C++ benchmark, irregular strided/broadcast ops; path not shown)

@@ -6,105 +6,105 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-namespace mx = mlx::core;
+using namespace mlx::core;
 
 void time_irregular_binary_ops_1D() {
-  auto device = mx::default_device();
+  auto device = default_device();
   int size = 1000000;
   int step = 2;
-  auto a = mx::random::uniform({size});
-  auto b = mx::random::uniform({size});
-  mx::eval(a, b);
+  auto a = random::uniform({size});
+  auto b = random::uniform({size});
+  eval(a, b);
   a = slice(a, {0}, {size}, {step});
   b = slice(b, {0}, {size}, {step});
-  TIMEM("1D strided", mx::add, a, b, device);
+  TIMEM("1D strided", add, a, b, device);
 }
 
 void time_irregular_binary_ops_2D() {
-  auto device = mx::default_device();
+  auto device = default_device();
   int size = 2048;
-  auto a = mx::random::uniform({size, size});
-  auto b = mx::random::uniform({size, size});
-  mx::eval(a, b);
-  TIMEM("2D regular", mx::add, a, b, device);
+  auto a = random::uniform({size, size});
+  auto b = random::uniform({size, size});
+  eval(a, b);
+  TIMEM("2D regular", add, a, b, device);
 
-  b = mx::transpose(b);
-  mx::eval(b);
-  TIMEM("2D mx::transpose", mx::add, a, b, device);
+  b = transpose(b);
+  eval(b);
+  TIMEM("2D transpose", add, a, b, device);
 
-  b = mx::random::uniform({size});
-  mx::eval(b);
-  TIMEM("2D broadcast dim 0", mx::add, a, b, device);
+  b = random::uniform({size});
+  eval(b);
+  TIMEM("2D broadcast dim 0", add, a, b, device);
 
-  b = mx::reshape(b, {size, 1});
-  mx::eval(b);
-  TIMEM("2D broadcast dim 1", mx::add, a, b, device);
+  b = reshape(b, {size, 1});
+  eval(b);
+  TIMEM("2D broadcast dim 1", add, a, b, device);
 }
 
 void time_irregular_binary_ops_3D() {
-  auto device = mx::default_device();
+  auto device = default_device();
   int d0 = 32;
   int d1 = 512;
   int d2 = 512;
-  auto a = mx::random::uniform({d0, d1, d2});
-  auto b = mx::random::uniform({d0, d1, d2});
-  TIMEM("3D regular", mx::add, a, b, device);
+  auto a = random::uniform({d0, d1, d2});
+  auto b = random::uniform({d0, d1, d2});
+  TIMEM("3D regular", add, a, b, device);
 
-  b = mx::transpose(b, {0, 2, 1});
-  TIMEM("3D mx::transpose", mx::add, a, b, device);
+  b = transpose(b, {0, 2, 1});
+  TIMEM("3D transpose", add, a, b, device);
 
-  b = mx::random::uniform({d1, d2});
-  TIMEM("3D broadcast dim 0", mx::add, a, b, device);
+  b = random::uniform({d1, d2});
+  TIMEM("3D broadcast dim 0", add, a, b, device);
 
-  b = mx::random::uniform({d0, 1, d2});
-  TIMEM("3D broadcast dim 1", mx::add, a, b, device);
+  b = random::uniform({d0, 1, d2});
+  TIMEM("3D broadcast dim 1", add, a, b, device);
 
-  b = mx::random::uniform({d0, d1, 1});
-  TIMEM("3D broadcast dim 2", mx::add, a, b, device);
+  b = random::uniform({d0, d1, 1});
+  TIMEM("3D broadcast dim 2", add, a, b, device);
 
-  b = mx::random::uniform({d2});
-  TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
+  b = random::uniform({d2});
+  TIMEM("3D broadcast dims 0, 1", add, a, b, device);
 
-  b = mx::random::uniform({d1, 1});
-  TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
+  b = random::uniform({d1, 1});
+  TIMEM("3D broadcast dims 0, 2", add, a, b, device);
 
-  b = mx::random::uniform({d0, 1, 1});
-  TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
+  b = random::uniform({d0, 1, 1});
+  TIMEM("3D broadcast dims 1, 2", add, a, b, device);
 }
 
 void time_irregular_binary_ops_4D() {
-  auto device = mx::default_device();
+  auto device = default_device();
   std::vector<int> shape = {8, 8, 512, 512};
-  auto a = mx::random::uniform(shape);
-  auto b = mx::random::uniform(shape);
+  auto a = random::uniform(shape);
+  auto b = random::uniform(shape);
 
-  TIMEM("4D regular", mx::add, a, b, device);
+  TIMEM("4D regular", add, a, b, device);
 
-  b = mx::transpose(b, {0, 1, 3, 2});
-  TIMEM("4D mx::transpose", mx::add, a, b, device);
+  b = transpose(b, {0, 1, 3, 2});
+  TIMEM("4D transpose", add, a, b, device);
 
   std::string om = "4D broadcast dims ";
   for (int i = 0; i < shape.size(); ++i) {
     shape[i] = 1;
-    b = mx::random::uniform(shape);
+    b = random::uniform(shape);
     std::ostringstream msg;
     msg << om << i;
-    TIMEM(msg.str(), mx::add, a, b, device);
+    TIMEM(msg.str(), add, a, b, device);
 
     for (int j = i + 1; j < shape.size(); ++j) {
       shape[j] = 1;
       std::ostringstream msg;
       msg << om << i << ", " << j;
-      b = mx::random::uniform(shape);
-      TIMEM(msg.str(), mx::add, a, b, device);
+      b = random::uniform(shape);
+      TIMEM(msg.str(), add, a, b, device);
       shape[j] = a.shape(j);
 
       for (int k = j + 1; k < shape.size(); ++k) {
         shape[k] = 1;
         std::ostringstream msg;
         msg << om << i << ", " << j << ", " << k;
-        b = mx::random::uniform(shape);
-        TIMEM(msg.str(), mx::add, a, b, device);
+        b = random::uniform(shape);
+        TIMEM(msg.str(), add, a, b, device);
         shape[k] = a.shape(k);
       }
     }
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
 }
 
 void time_irregular_reshape() {
-  auto device = mx::default_device();
+  auto device = default_device();
   std::vector<int> shape;
-  auto reshape_fn = [&shape, device](const mx::array& a) {
-    return mx::reshape(a, shape, device);
+  auto reshape_fn = [&shape, device](const array& a) {
+    return reshape(a, shape, device);
   };
 
   int size = 64;
   int d = 2 * size;
 
-  auto a = mx::random::uniform({d, d, d});
+  auto a = random::uniform({d, d, d});
 
   shape = {8 * size, size, size};
   TIMEM("3D contiguous", reshape_fn, a);
 
-  a = mx::transpose(a);
+  a = transpose(a);
   shape = {8 * size, size, size};
-  TIMEM("3D mx::transpose", reshape_fn, a);
+  TIMEM("3D transpose", reshape_fn, a);
 
-  a = mx::transpose(a, {1, 2, 0});
+  a = transpose(a, {1, 2, 0});
   shape = {8 * size, size, size};
-  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
+  TIMEM("3D transpose dims 1 2", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
+  a = broadcast_to(random::uniform({d, d}), {d, d, d});
   TIMEM("3D broadcast dim 0", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
+  a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
   TIMEM("3D broadcast dim 1", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
+  a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
   TIMEM("3D broadcast dim 2", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
+  a = broadcast_to(random::uniform({d}), {d, d, d});
   TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
+  a = broadcast_to(random::uniform({d, 1}), {d, d, d});
   TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
+  a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
 
-  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
+  a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
 }
 
 void time_irregular_astype_1D() {
-  auto device = mx::default_device();
+  auto device = default_device();
   int size = 1000000;
   int step = 2;
-  auto a = mx::random::uniform({size});
+  auto a = random::uniform({size});
   a = slice(a, {0}, {size}, {step});
-  TIMEM("1D strided", mx::astype, a, mx::int32, device);
+  TIMEM("1D strided", astype, a, int32, device);
 }
 
 void time_irregular_astype_2D() {
-  auto device = mx::default_device();
+  auto device = default_device();
   int size = 2048;
   std::vector<int> shape = {size, size};
 
-  auto a = mx::random::uniform(shape);
-  TIMEM("2D regular", mx::astype, a, mx::int32, device);
+  auto a = random::uniform(shape);
+  TIMEM("2D regular", astype, a, int32, device);
 
-  a = mx::transpose(a);
-  TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
+  a = transpose(a);
+  TIMEM("2D transpose", astype, a, int32, device);
 
-  a = mx::broadcast_to(mx::random::uniform({size}), shape);
-  TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
+  a = broadcast_to(random::uniform({size}), shape);
+  TIMEM("2D broadcast dim 0", astype, a, int32, device);
 
-  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
-  TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
+  a = broadcast_to(random::uniform({size, 1}), shape);
+  TIMEM("2D broadcast dim 1", astype, a, int32, device);
 }
 
 int main(int argc, char** argv) {
   if (argc > 1) {
     bool use_gpu = !strcmp(argv[1], "gpu");
-    set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
+    set_default_device(use_gpu ? Device::gpu : Device::cpu);
   }
-  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
+  std::cout << "Benchmarks for " << default_device() << std::endl;
   time_irregular_binary_ops_1D();
   time_irregular_binary_ops_2D();
   time_irregular_binary_ops_3D();

File: (C++ benchmark, single ops; path not shown)

@@ -3,20 +3,20 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-namespace mx = mlx::core;
+using namespace mlx::core;
 
 void time_creation_ops() {
   int M = 2000;
   int N = 500;
   auto shape = {M, N};
-  auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
+  auto full_fp32 = [&]() { return full(shape, 3.3f); };
   TIME(full_fp32);
-  auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
+  auto zeros_fp32 = [&]() { return zeros(shape, float32); };
   TIME(zeros_fp32);
-  auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
+  auto ones_fp32 = [&]() { return ones(shape, float32); };
   TIME(ones_fp32);
-  auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
+  auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
   TIME(arange_fp32);
 }
@@ -24,196 +24,194 @@ void time_type_conversions() {
   int M = 2000;
   int N = 500;
   auto shape = {M, N};
-  auto device = mx::default_device();
+  auto device = default_device();
 
-  auto a = mx::zeros(shape, mx::float32);
-  mx::eval(a);
-  TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
-  TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
+  auto a = zeros(shape, float32);
+  eval(a);
+  TIMEM("float32 to int32", astype, a, int32, device);
+  TIMEM("float32 to uint32", astype, a, uint32, device);
 
-  a = mx::zeros(shape, mx::int32);
-  mx::eval(a);
-  TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
+  a = zeros(shape, int32);
+  eval(a);
+  TIMEM("int32 to float32", astype, a, float32, device);
 
-  a = mx::zeros(shape, mx::bool_);
-  mx::eval(a);
-  TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
-  TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
-  TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
+  a = zeros(shape, bool_);
+  eval(a);
+  TIMEM("bool to float32", astype, a, float32, device);
+  TIMEM("bool to int32", astype, a, int32, device);
+  TIMEM("bool to uint32", astype, a, uint32, device);
 }
 
 void time_random_generation() {
   int M = 2000;
   int N = 500;
 
-  auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
+  auto uniform = [&]() { return random::uniform({M, N}, float32); };
   TIME(uniform);
-  auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
+  auto normal = [&]() { return random::normal({M, N}, float32); };
   TIME(normal);
 }
 
 void time_unary_ops() {
   int M = 2000;
   int N = 500;
-  auto device = mx::default_device();
+  auto device = default_device();
 
-  auto a = mx::random::normal({M, N});
-  mx::eval(a);
+  auto a = random::normal({M, N});
+  eval(a);
   TIME(mlx::core::abs, a, device);
-  TIME(mx::negative, a, device);
-  TIME(mx::sign, a, device);
-  TIME(mx::square, a, device);
+  TIME(negative, a, device);
+  TIME(sign, a, device);
+  TIME(square, a, device);
   TIME(mlx::core::sqrt, a, device);
-  TIME(mx::rsqrt, a, device);
+  TIME(rsqrt, a, device);
   TIME(mlx::core::exp, a, device);
 
-  a = mx::random::uniform({M, N});
+  a = random::uniform({M, N});
   TIME(mlx::core::log, a, device);
 }
 
 void time_binary_ops() {
   int M = 1000, N = 100, K = 10;
-  auto condition = mx::random::randint(0, 2, {M, N, K});
-  auto a = mx::random::uniform({M, N, K});
-  auto b = mx::random::uniform({M, N, K});
-  auto device = mx::default_device();
-  mx::eval(a, b);
+  auto condition = random::randint(0, 2, {M, N, K});
+  auto a = random::uniform({M, N, K});
+  auto b = random::uniform({M, N, K});
+  auto device = default_device();
+  eval(a, b);
 
-  TIME(mx::add, a, b, device);
-  TIME(mx::subtract, a, b, device);
-  TIME(mx::multiply, a, b, device);
-  TIME(mx::divide, a, b, device);
-  TIME(mx::maximum, a, b, device);
-  TIME(mx::minimum, a, b, device);
-  TIME(mx::where, condition, a, b, device);
+  TIME(add, a, b, device);
+  TIME(subtract, a, b, device);
+  TIME(multiply, a, b, device);
+  TIME(divide, a, b, device);
+  TIME(maximum, a, b, device);
+  TIME(minimum, a, b, device);
+  TIME(where, condition, a, b, device);
 
-  condition = mx::array({true});
-  b = mx::random::uniform({1});
-  mx::eval(b);
-  TIMEM("scalar", mx::add, a, b, device);
-  TIMEM("vector-scalar", mx::subtract, a, b, device);
-  TIMEM("scalar-vector", mx::subtract, b, a, device);
-  TIMEM("scalar", mx::multiply, a, b, device);
-  TIMEM("vector-scalar", mx::divide, a, b, device);
-  TIMEM("scalar-vector", mx::divide, b, a, device);
-  TIMEM("scalar-vector", mx::where, condition, a, b, device);
+  condition = array({true});
+  b = random::uniform({1});
+  eval(b);
+  TIMEM("scalar", add, a, b, device);
+  TIMEM("vector-scalar", subtract, a, b, device);
+  TIMEM("scalar-vector", subtract, b, a, device);
+  TIMEM("scalar", multiply, a, b, device);
+  TIMEM("vector-scalar", divide, a, b, device);
+  TIMEM("scalar-vector", divide, b, a, device);
+  TIMEM("scalar-vector", where, condition, a, b, device);
 
-  condition = mx::broadcast_to(mx::array({true}), {1000, 100});
-  a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
-  b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
-  mx::eval(a, b);
-  TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
-  TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
-  TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
-  TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
-  TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
+  condition = broadcast_to(array({true}), {1000, 100});
+  a = broadcast_to(random::uniform({1}), {1000, 100});
+  b = broadcast_to(random::uniform({1}), {1000, 100});
+  eval(a, b);
+  TIMEM("scalar-scalar broadcast", add, a, b, device);
+  TIMEM("scalar-scalar broadcast", subtract, a, b, device);
+  TIMEM("scalar-scalar broadcast", multiply, a, b, device);
+  TIMEM("scalar-scalar broadcast", divide, a, b, device);
+  TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
 }
 
 void time_strided_ops() {
   int M = 50, N = 50, O = 50, P = 50;
-  auto a = mx::random::uniform({M, N, O, P});
-  auto b = mx::random::uniform({M, N, O, P});
-  auto device = mx::default_device();
-  mx::eval(a, b);
-  TIMEM("non-strided", mx::add, a, b, device);
-  a = mx::transpose(a, {1, 0, 2, 3});
-  b = mx::transpose(b, {3, 2, 0, 1});
-  mx::eval(a, b);
-  TIMEM("strided", mx::add, a, b, device);
+  auto a = random::uniform({M, N, O, P});
+  auto b = random::uniform({M, N, O, P});
+  auto device = default_device();
+  eval(a, b);
+  TIMEM("non-strided", add, a, b, device);
+  a = transpose(a, {1, 0, 2, 3});
+  b = transpose(b, {3, 2, 0, 1});
+  eval(a, b);
+  TIMEM("strided", add, a, b, device);
 }
 
 void time_comparisons() {
   int M = 1000, N = 100, K = 10;
-  auto a = mx::random::uniform({M, N, K});
-  auto b = mx::random::uniform({M, N, K});
-  auto device = mx::default_device();
-  mx::eval(a, b);
-  TIME(mx::equal, a, b, device);
-  TIME(mx::greater, a, b, device);
-  TIME(mx::greater_equal, a, b, device);
-  TIME(mx::less, a, b, device);
-  TIME(mx::less_equal, a, b, device);
+  auto a = random::uniform({M, N, K});
+  auto b = random::uniform({M, N, K});
+  auto device = default_device();
+  eval(a, b);
+  TIME(equal, a, b, device);
+  TIME(greater, a, b, device);
+  TIME(greater_equal, a, b, device);
+  TIME(less, a, b, device);
+  TIME(less_equal, a, b, device);
 }
 
 void time_matvec() {
   int M = 2000, N = 200;
-  auto a = mx::random::uniform({M, N});
-  auto b = mx::random::uniform({N});
-  auto c = mx::random::uniform({M});
-  mx::eval(a, b, c);
-  auto matvec = [&]() { return mx::matmul(a, b); };
+  auto a = random::uniform({M, N});
+  auto b = random::uniform({N});
+  auto c = random::uniform({M});
+  eval(a, b, c);
+  auto matvec = [&]() { return matmul(a, b); };
   TIME(matvec);
 
-  auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
+  auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
   TIME(matvec_transpose);
 }
 
 void time_matmul() {
   int M = 1000, N = 1000, K = 1000;
-  auto a = mx::random::uniform({M, K});
-  auto b = mx::random::uniform({K, N});
-  auto device = mx::default_device();
-  mx::eval(a, b);
-  TIME(mx::matmul, a, b, device);
+  auto a = random::uniform({M, K});
+  auto b = random::uniform({K, N});
+  auto device = default_device();
+  eval(a, b);
+  TIME(matmul, a, b, device);
 
-  auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
+  auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
   TIME(transpose_matmul);
 }
 
 void time_reductions() {
-  auto a = mx::random::normal({10000, 1000});
-  mx::eval(a);
-  auto sum_all = [&a]() { return mx::sum(a, false); };
+  auto a = random::normal({10000, 1000});
+  eval(a);
+  auto sum_all = [&a]() { return sum(a, false); };
   TIME(sum_all);
 
-  auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
+  auto sum_along_0 = [&a]() { return sum(a, 0, false); };
   TIME(sum_along_0);
 
-  auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
+  auto sum_along_1 = [&a]() { return sum(a, 1, false); };
   TIME(sum_along_1);
 
-  auto prod_all = [&a]() { return mx::prod(a, false); };
+  auto prod_all = [&a]() { return prod(a, false); };
   TIME(prod_all);
 
-  auto all_true = [&a]() { return mx::all(a, false); };
+  auto all_true = [&a]() { return all(a, false); };
   TIME(all_true);
 
-  auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
+  auto all_along_0 = [&a]() { return all(a, 0, false); };
   TIME(all_along_0);
 
-  auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
+  auto all_along_1 = [&a]() { return all(a, 1, false); };
   TIME(all_along_1);
 
-  auto any_true = [&a]() { return mx::any(a, false); };
+  auto any_true = [&a]() { return any(a, false); };
   TIME(any_true);
 
-  auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
+  auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
   TIME(argmin_along_0);
 
-  auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
+  auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
   TIME(argmin_along_1);
 }
 
 void time_gather_scatter() {
-  auto a = mx::random::normal({1000, 768});
-  mx::eval(a);
-  auto indices = mx::random::randint(0, 1000, {256});
-  mx::eval(indices);
+  auto a = random::normal({1000, 768});
+  eval(a);
+  auto indices = random::randint(0, 1000, {256});
+  eval(indices);
 
-  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
+  auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
   TIME(embedding_lookup);
 
-  indices = mx::random::randint(0, 768 * 1000, {256 * 768});
-  mx::eval(indices);
+  indices = random::randint(0, 768 * 1000, {256 * 768});
+  eval(indices);
 
-  auto single_element_lookup = [&a, &indices]() {
-    return mx::take(a, indices);
-  };
+  auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
   TIME(single_element_lookup);
 
-  indices = mx::random::randint(0, 1000, {256});
-  auto updates = mx::random::normal({256, 1, 768});
-  mx::eval(indices, updates);
+  indices = random::randint(0, 1000, {256});
+  auto updates = random::normal({256, 1, 768});
+  eval(indices, updates);
 
   auto embedding_update = [&a, &indices, &updates]() {
     return scatter(a, indices, updates, 0);
@@ -225,10 +223,10 @@ void time_gather_scatter() {
   };
   TIME(embedding_add);
 
-  a = mx::reshape(a, {-1});
-  indices = mx::random::randint(0, 768 * 1000, {768 * 256});
-  updates = mx::random::normal({256 * 768, 1});
-  mx::eval(a, indices, updates);
+  a = reshape(a, {-1});
+  indices = random::randint(0, 768 * 1000, {768 * 256});
+  updates = random::normal({256 * 768, 1});
+  eval(a, indices, updates);
 
   auto single_element_update = [&a, &indices, &updates]() {
     return scatter(a, indices, updates, 0);
@@ -242,21 +240,21 @@ void time_gather_scatter() {
 }
 
 void time_divmod() {
-  auto a = mx::random::normal({1000});
-  auto b = mx::random::normal({1000});
-  mx::eval({a, b});
+  auto a = random::normal({1000});
+  auto b = random::normal({1000});
+  eval({a, b});
 
-  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
+  auto divmod_fused = [&a, &b]() { return divmod(a, b); };
   TIME(divmod_fused);
 
   auto divmod_separate = [&a, &b]() {
-    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
+    return std::vector<array>{floor_divide(a, b), remainder(a, b)};
   };
   TIME(divmod_separate);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
+  std::cout << "Benchmarks for " << default_device() << std::endl;
   time_creation_ops();
   time_type_conversions();
   time_unary_ops();
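`time_divmod` checks that the fused `divmod` primitive, which produces quotient and remainder in one pass, beats issuing `floor_divide` and `remainder` as two separate ops. The same comparison from Python, as a minimal sketch:

    import mlx.core as mx

    a = mx.random.normal(shape=(1000,))
    b = mx.random.normal(shape=(1000,))
    mx.eval(a, b)

    # Fused: one primitive, one pass over the data.
    q, r = mx.divmod(a, b)

    # Separate: two ops, two passes.
    q2, r2 = mx.floor_divide(a, b), mx.remainder(a, b)
    mx.eval(q, r, q2, r2)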

File: (Python benchmark driver; path not shown)

@@ -144,13 +144,6 @@ def reduction(op, axis, x):
     mx.eval(ys)
 
 
-def sum_and_add(axis, x, y):
-    z = x.sum(axis=axis, keepdims=True)
-    for i in range(50):
-        z = (z + y).sum(axis=axis, keepdims=True)
-    mx.eval(z)
-
-
 def softmax(axis, x):
     ys = []
     for i in range(100):
@@ -512,8 +505,5 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
-    elif args.benchmark == "sum_and_add":
-        print(bench(sum_and_add, axis, *xs))
-
     else:
         raise ValueError("Unknown benchmark")

File: (Python benchmark, PyTorch comparison; path not shown)

@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+from time import time
 
 import mlx.core as mx
 import torch

File: (Python benchmark, layer_norm; path not shown)

@@ -10,12 +10,7 @@ def layer_norm(x, w, b, eps):
     x = x.astype(mx.float32)
     mu = mx.mean(x, -1, keepdims=True)
     v = mx.var(x, -1, keepdims=True)
-    y = (x - mu) * mx.rsqrt(v + eps)
-    if w is not None:
-        y = y * w
-    if b is not None:
-        y = y + b
-    return y
+    return (x - mu) * mx.rsqrt(v + eps) * w + b
 
 
 def time_layer_norm():
@@ -41,28 +36,6 @@ def time_layer_norm():
     time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
     time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
 
-    f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
-    f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
-    g1 = mx.grad(f1, argnums=(0,))
-    g2 = mx.grad(f2, argnums=(0,))
-
-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    mx.eval(x, w, b, y)
-
-    def layer_norm_loop(g, x):
-        gx = x
-        for _ in range(32):
-            gx = g(gx, y)
-        return gx
-
-    time_fn(layer_norm_loop, g1, x)
-    time_fn(layer_norm_loop, g2, x)
-    time_fn(layer_norm_loop, mx.compile(g1), x)
-    time_fn(layer_norm_loop, mx.compile(g2), x)
-
 
 if __name__ == "__main__":
     time_layer_norm()
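The reference `layer_norm` above is the op-by-op composition that `mx.fast.layer_norm` replaces with a single fused kernel. A minimal sketch checking the two paths agree (float32 inputs assumed; the tolerance is ours):

    import mlx.core as mx

    x = mx.random.uniform(shape=(8, 1024, 4096))
    w = mx.random.uniform(shape=(4096,))
    b = mx.random.uniform(shape=(4096,))
    eps = 1e-5

    # Op-by-op: mean/variance, normalize, then the affine terms.
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    ref = (x - mu) * mx.rsqrt(v + eps) * w + b

    # Fused kernel.
    out = mx.fast.layer_norm(x, w, b, eps)
    print(mx.allclose(ref, out, atol=1e-4))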

File: (Python benchmark, rms_norm; path not shown)

@@ -9,10 +9,7 @@ def rms_norm(x, w, eps):
     ot = x.dtype
     x = x.astype(mx.float32)
     n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-    y = (x * n).astype(ot)
-    if w is not None:
-        y = y * w
-    return y
+    return (x * n).astype(ot) * w
 
 
 def time_rms_norm():
@@ -37,27 +34,6 @@ def time_rms_norm():
     time_fn(rms_norm_loop, mx.compile(g1), x, w)
     time_fn(rms_norm_loop, mx.compile(g2), x, w)
 
-    f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
-    f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
-    g1 = mx.grad(f1, argnums=(0,))
-    g2 = mx.grad(f2, argnums=(0,))
-
-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    mx.eval(x, w, y)
-
-    def rms_norm_loop(g, x):
-        gx = x
-        for _ in range(32):
-            gx = g(gx, y)
-        return gx
-
-    time_fn(rms_norm_loop, g1, x)
-    time_fn(rms_norm_loop, g2, x)
-    time_fn(rms_norm_loop, mx.compile(g1), x)
-    time_fn(rms_norm_loop, mx.compile(g2), x)
-
 
 if __name__ == "__main__":
     time_rms_norm()

File: (Python benchmark, scaled dot-product attention; path not shown)

@@ -1,189 +1,62 @@
-# Copyright © 2024 Apple Inc.
 import argparse
 import math
-import os
-import subprocess
-import time
 
 import mlx.core as mx
-import numpy as np
+from time_utils import time_fn
 
-device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
-device_name = device_name.decode("utf-8").strip("\n")
-
-N_warmup = 5
-N_iter_bench = 40
-N_iter_func = 8
-
-
-def bench(f, *args):
-    for i in range(N_warmup):
-        f(*args)
-
-    s = time.perf_counter_ns()
-    for i in range(N_iter_bench):
-        f(*args)
-    e = time.perf_counter_ns()
-    return (e - s) * 1e-9
-
-
-def mlx_sdpa_fused_inner(q, k, v, scale):
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
-
-
-def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
-    q_dtype = q.dtype
-    q = q * mx.array(scale, q_dtype)
-    n_q_heads = q.shape[-3]
-    n_kv_heads = k.shape[-3]
-    n_repeats = n_q_heads // n_kv_heads
-
-    B = q.shape[0]
-    L = q.shape[2]
-
-    if n_repeats > 1:
-        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
-        k = mx.expand_dims(k, 2)
-        v = mx.expand_dims(v, 2)
-
-    scores = q @ mx.swapaxes(k, -1, -2)
-    if f32softmax:
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
-    else:
-        scores = mx.softmax(scores, axis=-1)
-
-    out = scores @ v
-    if n_repeats > 1:
-        out = mx.reshape(out, [B, n_q_heads, L, -1])
-    return out
-
-
-def mlx_spda_unfused(q, k, v, scale, transpose):
-    q_out = q
-    if transpose:
-        k = mx.transpose(k, (0, 2, 1, 3))
-        v = mx.transpose(v, (0, 2, 1, 3))
-
-    for i in range(N_iter_func):
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-    mx.eval(q_out)
-    return q_out
-
-
-def mlx_spda_fused(q, k, v, scale, transpose):
-    q_out = q
-    if transpose:
-        k = mx.transpose(k, (0, 2, 1, 3))
-        v = mx.transpose(v, (0, 2, 1, 3))
-
-    for i in range(N_iter_func):
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-    mx.eval(q_out)
-    return q_out
-
-
-def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
-    shape_q = (
-        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
-    )
-    shape_kv = (
-        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
-    )
-
-    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
-    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
-    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
-
-    scale = math.sqrt(1.0 / head_dim)
-
-    q_mx = mx.array(q_np)
-    k_mx = mx.array(k_np)
-    v_mx = mx.array(v_np)
-
-    time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
-    time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
-
-    if transpose:
-        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
-        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
-        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
-
-    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
-    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
-
-    atol = 1e-5 if np_dtype == np.float32 else 1e-4
-
-    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
-        print(
-            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
-        )
-
-    return time_mlx_fused, time_mlx_unfused
-
-
-def get_gflop_count(B, M, N, K):
-    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
+MAX_SEQ = 300
+START_SEQ = 100
+SEQ_INCREMENT = 50
+
+
+def time_self_attention_primitives():
+    mx.random.seed(3)
+    B = 2
+    H = 38
+    D = 64
+    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
+        q = mx.random.uniform(shape=(B, H, R, D))
+        k = mx.random.uniform(shape=(B, H, R, D))
+        v = mx.random.uniform(shape=(B, H, R, D))
+        scale = 1.0 / math.sqrt(float(D))
+        mx.eval(q, k, v)
+
+        def sdpa_primitives(qs, ks, vs, alpha):
+            s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
+            p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
+            o = p @ vs
+            return o
+
+        time_fn(sdpa_primitives, q, k, v, scale)
+
+
+def time_self_attention_sdpa():
+    mx.random.seed(3)
+    B = 2
+    H = 38
+    D = 64
+    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
+        q = mx.random.uniform(shape=(B, H, R, D))
+        k = mx.random.uniform(shape=(B, H, R, D))
+        v = mx.random.uniform(shape=(B, H, R, D))
+        scale = 1.0 / math.sqrt(float(D))
+        mx.eval(q, k, v)
+
+        def sdpa_fused(qs, ks, vs, alpha):
+            o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
+            return o
+
+        time_fn(sdpa_fused, q, k, v, scale)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run gemm benchmarks")
+    parser = argparse.ArgumentParser("MLX benchmarks.")
+    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
+    args = parser.parse_args()
+    if args.gpu:
+        mx.set_default_device(mx.gpu)
+    else:
+        mx.set_default_device(mx.cpu)
 
-    dtypes = ("float16", "float32")[:1]
-    transposes = (False,)
-
-    # fmt: off
-    shapes_64 = (
-        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
-        (    1,    32,    32,       64,   32,    32),
-        (    1,    64,    64,       64,   32,    32),
-        (    1,   128,   128,       64,   32,    32),
-        (    1,   256,   256,       64,   32,    32),
-        (    1,   512,   512,       64,   32,    32),
-        (    1,  1024,  1024,       64,   32,    32),
-        (    1,  2048,  2048,       64,   32,    32),
-        (    1,  4096,  4096,       64,   32,    32),
-    )
-
-    shapes_80 = (
-        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
-        (    1,  1024,  1024,       80,   32,    32),
-        (    1,  2048,  2048,       80,   32,    32),
-        (    1,  4096,  4096,       80,   32,    32),
-    )
-
-    shapes_128 = (
-        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
-        (    1,  1024,  1024,      128,   32,    32),
-        (    1,  2048,  2048,      128,   32,    32),
-        (    1,  4096,  4096,      128,   32,    32),
-    )
-    # fmt: on
-
-    shapes = shapes_64 + shapes_80 + shapes_128
-
-    print("  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose,   dtype,  t_unfs,  t_fuse, diff%")
-
-    for dtype in dtypes:
-        for transpose in transposes:
-            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
-                np_dtype = getattr(np, dtype)
-                time_mlx_fused, time_mlx_unfused = bench_shape(
-                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
-                )
-                diff = time_mlx_unfused / time_mlx_fused - 1.0
-                t_str = 1 if transpose else 0
-                print(
-                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
-                )
+    time_self_attention_sdpa()
+    time_self_attention_primitives()

File: (Python benchmark, SDPA with grouped KV heads; path not shown)

@@ -4,92 +4,46 @@ import math
 import mlx.core as mx
 from time_utils import time_fn
 
-L = 16384
+L = 1024
 H = 32
-H_k = H // 4
+H_k = 32 // 4
 D = 128
-V = 128
-dtype = mx.float16
-loops = 10
-
-
-def upproject(x, w):
-    if w is None:
-        return x
-    else:
-        return x @ w.T
 
 
-def attention(q, k, v, mask=None, w=None):
-    def _sdpa(q, k, v):
-        B, Hq, L, D = q.shape
-        _, Hk, S, _ = k.shape
-        _, _, _, V = v.shape
-        q = q.reshape(B, Hk, Hq // Hk, L, D)
-        k = k[:, :, None, :, :]
-        v = v[:, :, None, :, :]
-        s = q @ k.transpose(0, 1, 2, 4, 3)
-        if mask is not None:
-            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
-            s = mx.where(m, s, mx.finfo(s.dtype).min)
-        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-        o = p @ v
-        return o.reshape(B, Hq, L, V)
-
-    for i in range(loops):
-        q = _sdpa(q, k, v)
-        q = upproject(q, w)
-    return q
-
-
-def sdpa(q, k, v, mask=None, w=None):
-    for i in range(loops):
-        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
-        q = upproject(q, w)
-    return q
+def attention(q, k, v):
+    B, Hq, L, D = q.shape
+    _, Hk, S, _ = k.shape
+    q = q.reshape(B, Hk, Hq // Hk, L, D)
+    k = k[:, :, None, :, :]
+    v = v[:, :, None, :, :]
+    s = q @ k.transpose(0, 1, 2, 4, 3)
+    p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
+    o = p @ v
+    return o.reshape(B, Hq, L, D)
+
+
+def sdpa(q, k, v):
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
 
 
 def time_self_attention_primitives():
     mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
-    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
-    mx.eval(q, k, v, w)
-    time_fn(attention, q, k, v, w=w)
+    q = mx.random.uniform(shape=(1, H, 1, D))
+    k = mx.random.uniform(shape=(1, H_k, L, D))
+    v = mx.random.uniform(shape=(1, H_k, L, D))
+    mx.eval(q, k, v)
+    time_fn(attention, q, k, v)
 
 
 def time_self_attention_sdpa():
     mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
-    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
-    mx.eval(q, k, v, w)
-    time_fn(sdpa, q, k, v, w=w)
-
-
-def time_self_attention_sdpa_with_mask():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
-    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
-    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
-    mask = mx.full((L,), True)
-    mask[L // 2 :] = False
-    mx.eval(q, k, v, mask, w)
-
-    def sdpa_mask(*args):
-        return sdpa(*args, mask=mask, w=w)
-
-    def attention_mask(*args):
-        return attention(*args, mask=mask, w=w)
-
-    time_fn(attention_mask, q, k, v)
-    time_fn(sdpa_mask, q, k, v)
+    q = mx.random.uniform(shape=(1, H, 1, D))
+    k = mx.random.uniform(shape=(1, H_k, L, D))
+    v = mx.random.uniform(shape=(1, H_k, L, D))
+    mx.eval(q, k, v)
+    time_fn(sdpa, q, k, v)
 
 
 if __name__ == "__main__":
     time_self_attention_sdpa()
     time_self_attention_primitives()
-    time_self_attention_sdpa_with_mask()
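Both variants of this file benchmark a hand-rolled attention, including the reshape that groups query heads over shared key/value heads, against the fused `mx.fast.scaled_dot_product_attention` kernel. A minimal sketch of the two paths agreeing on a single-token, decode-style query (shapes are ours):

    import math

    import mlx.core as mx

    B, Hq, Hk, L, D = 1, 32, 8, 1024, 128
    q = mx.random.uniform(shape=(B, Hq, 1, D))
    k = mx.random.uniform(shape=(B, Hk, L, D))
    v = mx.random.uniform(shape=(B, Hk, L, D))
    scale = 1.0 / math.sqrt(D)

    # Hand-rolled: group the query heads over the shared KV heads.
    qg = (scale * q).reshape(B, Hk, Hq // Hk, 1, D)
    s = qg @ k[:, :, None, :, :].transpose(0, 1, 2, 4, 3)
    p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
    ref = (p @ v[:, :, None, :, :]).reshape(B, Hq, 1, D)

    # Fused kernel.
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
    print(mx.allclose(ref, out, atol=1e-4))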

File: (Python benchmark, distributed collectives; deleted in this compare; path not shown)

@@ -1,55 +0,0 @@
-import time
-
-import mlx.core as mx
-
-rank = mx.distributed.init().rank()
-
-
-def timeit(fn, a):
-    # warmup
-    for _ in range(5):
-        mx.eval(fn(a))
-
-    its = 10
-    tic = time.perf_counter()
-    for _ in range(its):
-        mx.eval(fn(a))
-    toc = time.perf_counter()
-    ms = 1000 * (toc - tic) / its
-    return ms
-
-
-def all_reduce_benchmark():
-    a = mx.ones((5, 5), mx.int32)
-    its_per_eval = 100
-
-    def fn(x):
-        for _ in range(its_per_eval):
-            x = mx.distributed.all_sum(x)
-            x = x - 1
-        return x
-
-    ms = timeit(fn, a) / its_per_eval
-    if rank == 0:
-        print(f"All Reduce: time per iteration {ms:.6f} (ms)")
-
-
-def all_gather_benchmark():
-    a = mx.ones((5, 5), mx.int32)
-    its_per_eval = 100
-
-    def fn(x):
-        for _ in range(its_per_eval):
-            x = mx.distributed.all_gather(x)[0]
-        return x
-
-    ms = timeit(fn, a) / its_per_eval
-    if rank == 0:
-        print(f"All gather: time per iteration {ms:.6f} (ms)")
-
-
-if __name__ == "__main__":
-    all_reduce_benchmark()
-    all_gather_benchmark()
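The deleted file timed `mx.distributed` collectives, chaining many lazy `all_sum` calls per `mx.eval` to amortize evaluation overhead. A minimal sketch of the core pattern, to be launched with more than one process (for example via `mpirun`):

    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.ones((5, 5), mx.int32)

    # Every process contributes x; all of them receive the elementwise sum.
    total = mx.distributed.all_sum(x)
    mx.eval(total)

    if world.rank() == 0:
        print(total[0, 0].item())  # equals world.size() for an all-ones input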

File: (CMake helper for building metal libraries; path not shown)

@@ -1,7 +1,5 @@
 include(CMakeParseArguments)
 
-# clang format off
-#
 # ##############################################################################
 # Build metal library
 #
@@ -13,8 +11,6 @@ include(CMakeParseArguments)
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
 # files (like headers)
 #
-# clang format on
-
 macro(mlx_build_metallib)
   # Parse args
   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
@@ -25,7 +21,7 @@ macro(mlx_build_metallib)
   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
 
   # Collect compile options
-  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
+  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
 
   # Prepare metallib build command
   add_custom_command(

View File

@@ -22,12 +22,12 @@ You can do that in MLX directly:
This function performs that operation while leaving the implementation and
function transformations to MLX.

However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:

* The structure of the MLX library.
* Implementing a CPU operation.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformations.
* Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
@@ -55,7 +55,7 @@ C++:
   * Scale and sum two vectors element-wise
   * z = alpha * x + beta * y
   *
   * Use NumPy-style broadcasting between x and y
   * Inputs are upcasted to floats if needed
   **/
   array axpby(
@@ -66,7 +66,7 @@ C++:
       StreamOrDevice s = {}  // Stream on which to schedule the operation
   );

The simplest way to implement this is with existing operations:

.. code-block:: C++
@@ -153,6 +153,9 @@ more concrete:
    private:
     float alpha_;
     float beta_;
   };

The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -185,7 +188,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());

   // Upcast to float32 for non-floating point inputs x and y
   auto out_dtype = issubdtype(promoted_dtype, float32)
       ? promoted_dtype
       : promote_types(promoted_dtype, float32);
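
As a quick illustration of this rule, here is a hypothetical Python snippet. It
assumes the extension is already built and importable as
``mlx_sample_extensions``, the package name used in the benchmark at the end of
this tutorial, and that ``alpha`` and ``beta`` are passed positionally:

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby

   # int32 is not a sub-dtype of float32, so the promoted type is
   # promoted once more with float32 and the output comes out as float32
   x = mx.array([1, 2, 3], dtype=mx.int32)
   y = mx.ones(3, dtype=mx.int32)
   print(axpby(x, y, 2.0, 1.0).dtype)  # float32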
@@ -231,59 +234,49 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by implementing :meth:`Axpby::eval_cpu`.

The method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.

.. code-block:: C++

   template <typename T>
   void axpby_impl(
       const mx::array& x,
       const mx::array& y,
       mx::array& out,
       float alpha_,
       float beta_,
       mx::Stream stream) {
     // Allocate the output with `malloc_or_wait` which synchronously allocates
     // memory, potentially waiting if the system is under memory pressure
     out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));

     // Get the CPU command encoder and register input and output arrays
     auto& encoder = mx::cpu::get_command_encoder(stream);
     encoder.set_input_array(x);
     encoder.set_input_array(y);
     encoder.set_output_array(out);

     // Launch the CPU kernel
     encoder.dispatch([x_ptr = x.data<T>(),
                       y_ptr = y.data<T>(),
                       out_ptr = out.data<T>(),
                       size = out.size(),
                       shape = out.shape(),
                       x_strides = x.strides(),
                       y_strides = y.strides(),
                       alpha_,
                       beta_]() {
       // Cast alpha and beta to the relevant types
       T alpha = static_cast<T>(alpha_);
       T beta = static_cast<T>(beta_);

       // Do the element-wise operation for each output
       for (size_t out_idx = 0; out_idx < size; out_idx++) {
         // Map linear indices to offsets in x and y
         auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
         auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

         // We allocate the output to be contiguous and regularly strided
         // (defaults to row major) and hence it doesn't need additional mapping
         out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
       }
     });
   }

Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
@@ -291,32 +284,112 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
.. code-block:: C++

   void Axpby::eval_cpu(
       const std::vector<mx::array>& inputs,
       std::vector<mx::array>& outputs) {
     auto& x = inputs[0];
     auto& y = inputs[1];
     auto& out = outputs[0];

     // Dispatch to the correct dtype
     if (out.dtype() == mx::float32) {
       return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
     } else if (out.dtype() == mx::float16) {
       return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
     } else if (out.dtype() == mx::bfloat16) {
       return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
     } else if (out.dtype() == mx::complex64) {
       return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
     } else {
       throw std::runtime_error(
           "Axpby is only supported for floating point types.");
     }
   }

Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here.
Implementing the GPU Back-end Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -347,8 +420,8 @@ element in the output.
       constant const float& alpha [[buffer(3)]],
       constant const float& beta [[buffer(4)]],
       constant const int* shape [[buffer(5)]],
       constant const int64_t* x_strides [[buffer(6)]],
       constant const int64_t* y_strides [[buffer(7)]],
       constant const int& ndim [[buffer(8)]],
       uint index [[thread_position_in_grid]]) {
     // Convert linear indices to offsets in array
@@ -365,10 +438,24 @@ each instantiation a unique host name so we can identify it.
.. code-block:: C++

   instantiate_kernel("axpby_general_float32", axpby_general, float)
   instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
   instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
   instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
The logic to determine the kernel, set the inputs, resolve the grid dimensions, The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
@@ -407,7 +494,7 @@ below.
     // Prepare to encode kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder.set_compute_pipeline_state(kernel);

     // Kernel parameters are registered with buffer indices corresponding to
     // those in the kernel declaration at axpby.metal
@@ -422,14 +509,14 @@ below.
     compute_encoder.set_output_array(out, 2);

     // Encode alpha and beta
     compute_encoder.set_bytes(alpha_, 3);
     compute_encoder.set_bytes(beta_, 4);

     // Encode shape, strides and ndim
     compute_encoder.set_vector_bytes(x.shape(), 5);
     compute_encoder.set_vector_bytes(x.strides(), 6);
     compute_encoder.set_bytes(y.strides(), 7);
     compute_encoder.set_bytes(ndim, 8);

     // We launch 1 thread for each input and make sure that the number of
     // threads in any given threadgroup is not higher than the max allowed
@@ -443,7 +530,7 @@ below.
     // Launch the grid with the given number of threads divided among
     // the given threadgroups
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }

We can now call the :meth:`axpby` operation on both the CPU and the GPU!
@@ -751,7 +838,7 @@ Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined.

.. code-block:: python
@@ -759,11 +846,13 @@ with the naive :meth:`simple_axpby` we first defined.
   from mlx_sample_extensions import axpby
   import time

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       return alpha * x + beta * y

   M = 4096
   N = 4096

   x = mx.random.normal((M, N))
   y = mx.random.normal((M, N))
@@ -774,24 +863,24 @@ with the naive :meth:`simple_axpby` we first defined.
   def bench(f):
       # Warm up
       for i in range(5):
           z = f(x, y, alpha, beta)
           mx.eval(z)

       # Timed run
       s = time.time()
       for i in range(100):
           z = f(x, y, alpha, beta)
           mx.eval(z)
       e = time.time()
       return 1000 * (e - s) / 100

   simple_time = bench(simple_axpby)
   custom_time = bench(axpby)

   print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
modest improvements right away!

This operation is now good to be used to build other operations, in

View File

@@ -1,121 +0,0 @@
.. _mlx_in_cpp:
Using MLX in C++
================
You can use MLX in a C++ project with CMake.
.. note::
   This guide is based on the following `example using MLX in C++
   <https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
First install MLX:
.. code-block:: bash
pip install -U mlx
You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.
Next make an example program in ``example.cpp``:
.. code-block:: C++
   #include <iostream>

   #include "mlx/mlx.h"

   namespace mx = mlx::core;

   int main() {
     auto x = mx::array({1, 2, 3});
     auto y = mx::array({1, 2, 3});
     std::cout << x + y << std::endl;
     return 0;
   }
The next step is to set up a CMake file in ``CMakeLists.txt``:
.. code-block:: cmake
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Depending on how you installed MLX, you may need to tell CMake where to
find it.
If you installed MLX with Python, then add the following to the CMake file:
.. code-block:: cmake
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
.. code-block:: cmake
set(MLX_ROOT "/path/to/mlx/")
Next, instruct CMake to find MLX:
.. code-block:: cmake
find_package(MLX CONFIG REQUIRED)
Finally, add the ``example.cpp`` program as an executable and link MLX.
.. code-block:: cmake
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
You can build the example with:
.. code-block:: bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
And run it with:
.. code-block:: bash
./build/example
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
.. list-table:: Package Variables
   :widths: 20 20
   :header-rows: 1

   * - Variable
     - Description
   * - MLX_FOUND
     - ``True`` if MLX is found
   * - MLX_INCLUDE_DIRS
     - Include directory
   * - MLX_LIBRARIES
     - Libraries to link against
   * - MLX_CXX_FLAGS
     - Additional compiler flags
   * - MLX_BUILD_ACCELERATE
     - ``True`` if MLX was built with Accelerate
   * - MLX_BUILD_METAL
     - ``True`` if MLX was built with Metal

View File

@@ -45,7 +45,6 @@ are the CPU and GPU.
   usage/numpy
   usage/distributed
   usage/using_streams
   usage/export

.. toctree::
   :caption: Examples
@@ -62,7 +61,6 @@ are the CPU and GPU.
   python/array
   python/data_types
   python/devices_and_streams
   python/export
   python/ops
   python/random
   python/transforms
@@ -88,4 +86,3 @@ are the CPU and GPU.
   dev/extensions
   dev/metal_debugger
   dev/custom_metal_kernels
   dev/mlx_in_cpp

View File

@@ -1,5 +1,3 @@
.. _build_and_install:
Build and Install
=================
@@ -55,7 +53,7 @@ Build Requirements
^^^^^^^^^^^^^^^^^^

- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0

.. note::
@@ -211,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@@ -51,20 +51,11 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``float32``
     - 4
     - 32-bit float
   * - ``float64``
     - 8
     - 64-bit double
   * - ``complex64``
     - 8
     - 64-bit complex float

.. note::

   Arrays with type ``float64`` only work with CPU operations. Using
   ``float64`` arrays on the GPU will result in an exception.
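
For example, a small illustrative snippet (it assumes an Apple silicon machine
where the Metal GPU is available as ``mx.gpu``):

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1.0, 2.0], dtype=mx.float64)

   # Fine: run the operation on the CPU stream
   mx.eval(mx.add(a, 1.0, stream=mx.cpu))

   # Raises an exception: float64 is not supported on the GPU
   mx.eval(mx.add(a, 1.0, stream=mx.gpu))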
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
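
For instance, a quick illustration (not part of the table above):

.. code-block:: python

   import mlx.core as mx

   print(mx.issubdtype(mx.float32, mx.floating))  # True
   print(mx.issubdtype(mx.int32, mx.floating))    # False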
@@ -75,4 +66,3 @@ documentation for more information. Use :func:`issubdtype` to determine if one
   Dtype
   DtypeCategory
   issubdtype
   finfo

View File

@@ -1,14 +0,0 @@
.. _export:
Export Functions
================
.. currentmodule:: mlx.core
.. autosummary::
   :toctree: _autosummary

   export_function
   import_function
   exporter
   export_to_dot

View File

@@ -12,4 +12,5 @@ Fast
   layer_norm
   rope
   scaled_dot_product_attention
   metal_kernel

View File

@@ -5,8 +5,8 @@ Linear Algebra
.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   inv
   tri_inv
@@ -18,7 +18,3 @@ Linear Algebra
   svd
   eigvalsh
   eigh
   lu
   lu_factor
   solve
   solve_triangular

View File

@@ -174,7 +174,6 @@ In detail:
   value_and_grad
   quantize
   average_gradients

.. toctree::

View File

@@ -12,7 +12,6 @@ Layers
   ALiBi
   AvgPool1d
   AvgPool2d
   AvgPool3d
   BatchNorm
   CELU
   Conv1d
@@ -42,7 +41,6 @@ Layers
   LSTM
   MaxPool1d
   MaxPool2d
   MaxPool3d
   Mish
   MultiHeadAttention
   PReLU

View File

@@ -32,7 +32,6 @@ Operations
   atleast_2d
   atleast_3d
   bitwise_and
   bitwise_invert
   bitwise_or
   bitwise_xor
   block_masked_mm
@@ -90,7 +89,6 @@ Operations
   isneginf
   isposinf
   issubdtype
   kron
   left_shift
   less
   less_equal
@@ -146,8 +144,6 @@ Operations
   sign
   sin
   sinh
   slice
   slice_update
   softmax
   sort
   split
@@ -172,7 +168,6 @@ Operations
   tri
   tril
   triu
   unflatten
   var
   view
   where

View File

@@ -421,77 +421,3 @@ the most opportunity to optimize the computation graph:
   # Compiling the outer function is good to do as it will likely
   # be faster even though the inner functions are compiled
   fun = mx.compile(outer)
.. _shapeless_compile:
Shapeless Compilation
---------------------
When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.
.. code-block:: python
   def fun(x, y):
       return mx.abs(x + y)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.array(1.0)
   y = mx.array(-2.0)

   # First call compiles the function
   print(compiled_fun(x, y))

   # Second call with different shapes
   # does not recompile the function
   x = mx.array([1.0, -6.0])
   y = mx.array([-2.0, 3.0])
   print(compiled_fun(x, y))
Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:
.. code-block:: python
   def fun(x):
       return x.reshape(x.shape[0] * x.shape[1], -1)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.random.uniform(shape=(2, 3, 4))
   out = compiled_fun(x)

   x = mx.random.uniform(shape=(5, 5, 3))
   # Error, can't reshape (5, 5, 3) to (6, -1)
   out = compiled_fun(x)
The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
.. code-block:: python
   def fun(x):
       return x.flatten(0, 1)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.random.uniform(shape=(2, 3, 4))
   out = compiled_fun(x)

   x = mx.random.uniform(shape=(5, 5, 3))
   # Ok
   out = compiled_fun(x)

View File

@@ -5,27 +5,21 @@ Distributed Communication
.. currentmodule:: mlx.core.distributed

MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_, a
  full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
  faster for Thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported or not as fast as they should be.
   We are adding more and tuning the ones we have as we are figuring out the
   best way to do distributed computing on Macs using MLX.
Getting Started
---------------

A distributed program in MLX is as simple as:

.. code:: python
@@ -36,79 +30,74 @@ A distributed program in MLX is as simple as:
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting, similar to the one below:

.. code:: python

   import mlx.core as mx

   x = ...
   world = mx.distributed.init()
   # No need for the check, we can simply do x = mx.distributed.all_sum(x)
   if world.size() > 1:
       x = mx.distributed.all_sum(x)
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLX provides ``mlx.launch``, a helper script to launch distributed programs.
Continuing with our initial example, we can run it on localhost with 4 processes using

.. code:: shell

   $ mlx.launch -n 4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)

.. code:: shell

   $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.
Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.

.. note::
   After a distributed backend is successfully initialized :func:`init` will
   return **the same backend** if called without arguments or with backend set to
   ``any``.
The following examples aim to clarify the backend initialization logic in MLX:

.. code:: python

   # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
   world = mx.distributed.init(backend="mpi")
   world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

   # Case 2: Initialize any backend
   world = mx.distributed.init(backend="any")  # equivalent to no arguments
   world2 = mx.distributed.init()  # same as above

   # Case 3: Initialize both backends at the same time
   world_mpi = mx.distributed.init(backend="mpi")
   world_ring = mx.distributed.init(backend="ring")
   world_any = mx.distributed.init()  # same as MPI because it was initialized first!
Training Example
----------------
@@ -152,13 +141,12 @@ everything else remaining the same.
   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads
       )

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
@@ -166,179 +154,13 @@ everything else remaining the same.
       optimizer.update(model, grads)
       return loss
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = mlx.nn.average_gradients(grads)  # <---- This line was added
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests, the nodes are connected in a ring, which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3,
and so on and so forth. As a result, :func:`send` and :func:`recv` with
arbitrary sender and receiver are not supported in the ring backend.
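
To make the neighbor-only pattern concrete, here is a minimal sketch (not part
of the original page) in which each rank passes a buffer to its right neighbor;
it assumes the :func:`send` and :func:`recv_like` operations of the distributed
API:

.. code:: python

   import mlx.core as mx

   group = mx.distributed.init(backend="ring")
   rank, size = group.rank(), group.size()

   # In a ring only rank - 1 and rank + 1 are reachable
   right = (rank + 1) % size
   left = (rank - 1) % size

   x = mx.ones((4,)) * rank

   # Send our buffer to the right neighbor and receive the
   # left neighbor's buffer
   y = mx.distributed.recv_like(x, left)
   mx.eval(mx.distributed.send(x, right), y)
   print(rank, y)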
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
   [
       {"ssh": "hostname1", "ips": ["123.123.123.1"]},
       {"ssh": "hostname2", "ips": ["123.123.123.2"]},
       {"ssh": "hostname3", "ips": ["123.123.123.3"]},
       {"ssh": "hostname4", "ips": ["123.123.123.4"]}
   ]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node and run the script, which will listen for connections on each of the
provided IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and
accept a connection from ``123.123.123.4`` and so on and so forth.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.

To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via Thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
To validate your connection without configuring anything,
``mlx.distributed_config`` can also plot the ring using DOT format.
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.

View File

@@ -1,288 +0,0 @@
.. _export_usage:
Exporting Functions
===================
.. currentmodule:: mlx.core
MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check out the :ref:`API documentation
<export>`.
Basics of Exporting
-------------------
Let's start with a simple example:
.. code-block:: python
   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)

   mx.export_function("add.mlxfn", fun, x, y)
To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:
.. code-block:: python
   add_fun = mx.import_function("add.mlxfn")

   out, = add_fun(mx.array(1.0), mx.array(2.0))

   # Prints: array(3, dtype=float32)
   print(out)

   out, = add_fun(mx.array(1.0), mx.array(3.0))

   # Prints: array(4, dtype=float32)
   print(out)

   # Raises an exception
   add_fun(mx.array(1), mx.array(3.0))

   # Raises an exception
   add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different than the shapes and types of the
example inputs we exported the function with.
Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:
.. code-block:: python
   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)

   # Both arguments to fun are positional
   mx.export_function("add.mlxfn", fun, x, y)

   # Same as above
   mx.export_function("add.mlxfn", fun, (x, y))

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y)

   # Also ok
   out, = imported_fun((x, y))
You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.
.. code-block:: python
   def fun(x, y):
       return x + y

   # One argument to fun is positional, the other is a kwarg
   mx.export_function("add.mlxfn", fun, x, y=y)

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y=y)

   # Also ok
   out, = imported_fun((x,), {"y": y})

   # Raises since the keyword argument is missing
   out, = imported_fun(x, y)

   # Raises since the keyword argument has the wrong key
   out, = imported_fun(x, z=y)
Exporting Modules
-----------------
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:
.. code-block:: python
   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
       return model(x)

   mx.export_function("model.mlxfn", call, mx.zeros(4))
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.
.. note::
For enclosed arrays inside an exported function, be extra careful to ensure
they are evaluated. The computation graph that gets exported will include
the computation that produces enclosed inputs.
If the above example was missing ``mx.eval(model.parameters())``, the
exported function would include the random initialization of the
:obj:`mlx.nn.Module` parameters.
If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:
.. code-block:: python
   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

   params = dict(tree_flatten(model.parameters()))
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
Shapeless Exports
-----------------
Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:
.. code-block:: python
   mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
   imported_abs = mx.import_function("fun.mlxfn")

   # Ok
   out, = imported_abs(mx.array(-1.0))

   # Also ok
   out, = imported_abs(mx.array([-1.0, -2.0]))
With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.
Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.
Exporting Multiple Traces
-------------------------
In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).
The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:
.. code-block:: python
   def fun(x, y=None):
       constant = mx.array(3.0)
       if y is not None:
           x += y
       return x + constant

   with mx.exporter("fun.mlxfn", fun) as exporter:
       exporter(mx.array(1.0))
       exporter(mx.array(1.0), y=mx.array(0.0))

   imported_function = mx.import_function("fun.mlxfn")

   # Call the function with y=None
   out, = imported_function(mx.array(1.0))
   print(out)

   # Call the function with y specified
   out, = imported_function(mx.array(1.0), y=mx.array(1.0))
   print(out)
In the above example the function constant data (i.e. ``constant``) is only
saved once.
Transformations with Imported Functions
---------------------------------------
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:
.. code-block:: python
   def fun(x):
       return mx.sin(x)

   x = mx.array(0.0)
   mx.export_function("sine.mlxfn", fun, x)

   imported_fun = mx.import_function("sine.mlxfn")

   # Take the derivative of the imported function
   dfdx = mx.grad(lambda x: imported_fun(x)[0])

   # Prints: array(1, dtype=float32)
   print(dfdx(x))

   # Compile the imported function
   compiled_fun = mx.compile(imported_fun)

   # Prints: array(0, dtype=float32)
   print(compiled_fun(x)[0])
Importing Functions in C++
--------------------------
Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.
Next, export a simple function from Python:
.. code-block:: python
   def fun(x, y):
       return mx.exp(x + y)

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("fun.mlxfn", fun, x, y)
Import and run the function in C++ with only a few lines of code:
.. code-block:: c++
   auto fun = mx::import_function("fun.mlxfn");
   auto inputs = {mx::array(1.0), mx::array(1.0)};
   auto outputs = fun(inputs);

   // Prints: array(2, dtype=float32)
   std::cout << outputs[0] << std::endl;
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.
More Examples
-------------
Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_

View File

@@ -184,8 +184,8 @@ Let's time these two different versions:
   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.

Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -1,105 +0,0 @@
:orphan:
.. _usage_launch_distributed:
Launching Distributed Programs
==============================
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Usage
-----
The minimal usage example of ``mlx.launch`` is simply
.. code:: shell
mlx.launch --hosts ip1,ip2 my_script.py
or for testing on localhost
.. code:: shell
mlx.launch -n 2 my_script.py
The ``mlx.launch`` command connects to the provided host and launches the input
script on each host. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Providing Hosts
^^^^^^^^^^^^^^^^
Hosts can be provided as command line arguments, like above, but the way to
fully define a list of hosts is via a JSON hostfile. The hostfile has a very
simple schema. It is simply a list of objects that define each host via a
hostname to ssh to and a list of IPs to utilize for the communication.
.. code:: json
   [
       {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
       {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
   ]
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^^
In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover, the input script and python binary need to be on each
host at the same path. A good checklist to debug errors is the following:
* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
.. _mpi_specifics:
MPI Specifics
-------------
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
to connect to every other node
* ``mpirun`` needs to be available on every node at the same path
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@@ -21,13 +21,11 @@ Let's convert an array to NumPy and back.
.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert
   to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
   buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating
an array view:

.. code-block:: python
@@ -37,16 +35,10 @@ an array view:
   a_view[0] = 1
   print(a[0].item())  # 1

.. note::

   NumPy arrays with type ``float64`` will by default be converted to MLX
   arrays with type ``float32``.

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:
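
A minimal sketch of the kind of function being described below (names are
illustrative; the squaring happens through a NumPy view, outside of MLX):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   def f(x):
       x_view = np.array(x, copy=False)
       x_view[:] *= x_view  # modify the memory of x without telling MLX
       return x.sum()

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print(y.item())   # 9.0: the view modification is visible in the value
   print(df.item())  # 1.0: the gradient of the sum alone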
@@ -64,12 +56,11 @@ Let's demonstrate this in an example:
The function ``f`` indirectly modifies the array ``x`` through a memory view. The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
last line outputting ``1.0``, representing the gradient of the sum operation representing the gradient of the sum operation alone.
alone. The squaring of ``x`` occurs externally to MLX, meaning that no The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
gradient is incorporated. It's important to note that a similar issue arises It's important to note that a similar issue arises during array conversion and copying.
during array conversion and copying. For instance, a function defined as For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed. even though no in-place operations on MLX memory are executed.
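To make the caveat in this hunk concrete, here is a minimal sketch of the
pattern the docs describe (assuming ``mlx.core`` and NumPy; illustrative, not
part of the diff):

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    def f(x):
        x_view = np.array(x, copy=False)  # NumPy view into MLX memory
        x_view **= 2  # external in-place square, invisible to MLX's graph
        return x.sum()

    x = mx.array([3.0])
    y, dfdx = mx.value_and_grad(f)(x)
    print(y.item())     # 9.0 -- the external squaring did modify x
    print(dfdx.item())  # 1.0 -- the gradient of the sum alone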
 PyTorch
@@ -80,8 +71,7 @@ PyTorch
    PyTorch Support for :obj:`memoryview` is experimental and can break for
    multi-dimensional arrays. Casting to NumPy first is advised for now.
-PyTorch supports the buffer protocol, but it requires an explicit
-:obj:`memoryview`.
+PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
 .. code-block:: python
@@ -92,8 +82,7 @@ PyTorch supports the buffer protocol, but it requires an explicit
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())
-Conversion from PyTorch tensors back to arrays must be done via intermediate
-NumPy arrays with ``numpy()``.
+Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
 JAX
 ---
@@ -111,8 +100,7 @@ JAX fully supports the buffer protocol.
 TensorFlow
 ----------
-TensorFlow supports the buffer protocol, but it requires an explicit
-:obj:`memoryview`.
+TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
 .. code-block:: python
@@ -1,22 +0,0 @@
cmake_minimum_required(VERSION 3.27)
project(example LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Comment out the following two commands only if the MLX C++ library is
# installed, and set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
@@ -1,26 +0,0 @@
## Build and Run
Install MLX with Python:
```bash
pip install "mlx>=0.22"
```
Build the C++ example:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
Run the C++ example:
```
./build/example
```
which should output:
```
array([2, 4, 6], dtype=int32)
```
@@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
namespace mx = mlx::core;
int main() {
auto x = mx::array({1, 2, 3});
auto y = mx::array({1, 2, 3});
std::cout << x + y << std::endl;
return 0;
}
@@ -4,19 +4,19 @@
 #include "mlx/mlx.h"
-namespace mx = mlx::core;
+using namespace mlx::core;
 int main() {
-  if (!mx::distributed::is_available()) {
+  if (!distributed::is_available()) {
     std::cout << "No communication backend found" << std::endl;
     return 1;
   }
-  auto global_group = mx::distributed::init();
+  auto global_group = distributed::init();
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
-  mx::array x = mx::ones({10});
-  mx::array out = mx::distributed::all_sum(x, global_group);
+  array x = ones({10});
+  array out = distributed::all_sum(x, global_group);
   std::cout << out << std::endl;
 }
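A usage sketch for this example (not part of the diff; the binary name and
process count are assumptions about how the examples target is built):

```bash
mpirun -np 2 ./build/distributed_example
# 0 / 2
# 1 / 2
# ...each rank then prints a length-10 array of 2s (all_sum of ones over 2 ranks)
```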
@@ -10,7 +10,7 @@
 /**
  * An example of linear regression with MLX.
  */
-namespace mx = mlx::core;
+using namespace mlx::core;
 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.01;
   // True parameters
-  auto w_star = mx::random::normal({num_features});
+  auto w_star = random::normal({num_features});
   // The input examples (design matrix)
-  auto X = mx::random::normal({num_examples, num_features});
+  auto X = random::normal({num_examples, num_features});
   // Noisy labels
-  auto eps = 1e-2 * mx::random::normal({num_examples});
-  auto y = mx::matmul(X, w_star) + eps;
+  auto eps = 1e-2 * random::normal({num_examples});
+  auto y = matmul(X, w_star) + eps;
   // Initialize random parameters
-  mx::array w = 1e-2 * mx::random::normal({num_features});
+  array w = 1e-2 * random::normal({num_features});
-  auto loss_fn = [&](mx::array w) {
-    auto yhat = mx::matmul(X, w);
-    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
+  auto loss_fn = [&](array w) {
+    auto yhat = matmul(X, w);
+    return (0.5f / num_examples) * sum(square(yhat - y));
   };
-  auto grad_fn = mx::grad(loss_fn);
+  auto grad_fn = grad(loss_fn);
   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grads = grad_fn(w);
-    w = w - learning_rate * grads;
-    mx::eval(w);
+    auto grad = grad_fn(w);
+    w = w - learning_rate * grad;
+    eval(w);
   }
   auto toc = timer::time();
   auto loss = loss_fn(w);
-  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
+  auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
             << ", Throughput " << throughput << " (it/s)." << std::endl;
@@ -10,7 +10,7 @@
 /**
  * An example of logistic regression with MLX.
  */
-namespace mx = mlx::core;
+using namespace mlx::core;
 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.1;
   // True parameters
-  auto w_star = mx::random::normal({num_features});
+  auto w_star = random::normal({num_features});
   // The input examples
-  auto X = mx::random::normal({num_examples, num_features});
+  auto X = random::normal({num_examples, num_features});
   // Labels
-  auto y = mx::matmul(X, w_star) > 0;
+  auto y = matmul(X, w_star) > 0;
   // Initialize random parameters
-  mx::array w = 1e-2 * mx::random::normal({num_features});
+  array w = 1e-2 * random::normal({num_features});
-  auto loss_fn = [&](mx::array w) {
-    auto logits = mx::matmul(X, w);
+  auto loss_fn = [&](array w) {
+    auto logits = matmul(X, w);
     auto scale = (1.0f / num_examples);
-    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
+    return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
   };
-  auto grad_fn = mx::grad(loss_fn);
+  auto grad_fn = grad(loss_fn);
   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grads = grad_fn(w);
-    w = w - learning_rate * grads;
-    mx::eval(w);
+    auto grad = grad_fn(w);
+    w = w - learning_rate * grad;
+    eval(w);
   }
   auto toc = timer::time();
   auto loss = loss_fn(w);
-  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
+  auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
             << throughput << " (it/s)." << std::endl;
@@ -5,27 +5,27 @@
 #include "mlx/mlx.h"
-namespace mx = mlx::core;
+using namespace mlx::core;
 int main() {
   // To use Metal debugging and profiling:
   // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
   // 2. Run with MTL_CAPTURE_ENABLED=1.
-  mx::metal::start_capture("mlx_trace.gputrace");
+  metal::start_capture("mlx_trace.gputrace");
   // Start at index two because the default GPU and CPU streams have indices
   // zero and one, respectively. This naming matches the label assigned to each
   // stream's command queue.
-  auto s2 = new_stream(mx::Device::gpu);
-  auto s3 = new_stream(mx::Device::gpu);
+  auto s2 = new_stream(Device::gpu);
+  auto s3 = new_stream(Device::gpu);
-  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
-  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
-  auto x = mx::add(a, a, s2);
-  auto y = mx::add(b, b, s3);
+  auto a = arange(1.f, 10.f, 1.f, float32, s2);
+  auto b = arange(1.f, 10.f, 1.f, float32, s3);
+  auto x = add(a, a, s2);
+  auto y = add(b, b, s3);
   // The multiply will happen on the default stream.
-  std::cout << mx::multiply(x, y) << std::endl;
+  std::cout << multiply(x, y) << std::endl;
-  mx::metal::stop_capture();
+  metal::stop_capture();
 }
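The two numbered comments in this example correspond to the following shell
steps; a sketch, since the example binary's name depends on how the examples
are configured:

```bash
cmake -B build -DMLX_METAL_DEBUG=ON -DCMAKE_BUILD_TYPE=Release
cmake --build build
MTL_CAPTURE_ENABLED=1 ./build/metal_capture  # writes mlx_trace.gputrace
```

The resulting `mlx_trace.gputrace` can then be opened with Xcode's Metal
debugger.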
@@ -5,11 +5,11 @@
 #include "mlx/mlx.h"
-namespace mx = mlx::core;
+using namespace mlx::core;
 void array_basics() {
   // Make a scalar array:
-  mx::array x(1.0);
+  array x(1.0);
   // Get the value out of it:
   auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {
   // The datatype should be float32:
   auto dtype = x.dtype();
-  assert(dtype == mx::float32);
+  assert(dtype == float32);
   // Specify the dtype when constructing the array:
-  x = mx::array(1, mx::int32);
-  assert(x.dtype() == mx::int32);
+  x = array(1, int32);
+  assert(x.dtype() == int32);
   x.item<int>(); // OK
   // x.item<float>(); // Undefined!
   // Make a multidimensional array:
-  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
   // mlx is row-major by default so the first row of this array
   // is [1.0, 2.0] and the second row is [3.0, 4.0]
   // Make an array of shape {2, 2} filled with ones:
-  auto y = mx::ones({2, 2});
+  auto y = ones({2, 2});
   // Pointwise add x and y:
-  auto z = mx::add(x, y);
+  auto z = add(x, y);
   // Same thing:
   z = x + y;
   // mlx is lazy by default. At this point `z` only
   // has a shape and a type but no actual data:
-  assert(z.dtype() == mx::float32);
+  assert(z.dtype() == float32);
   assert(z.shape(0) == 2);
   assert(z.shape(1) == 2);
@@ -63,33 +63,33 @@ void array_basics() {
   // and inputs. When `eval` is called on an array (or arrays), the array and
   // all of its dependencies are recursively evaluated to produce the result.
   // Once an array is evaluated, it has data and is detached from its inputs.
-  mx::eval(z);
+  eval(z);
-  // Of course the array can still be an input to other operations. You can
-  // even call eval on the array again, this will just be a no-op:
-  mx::eval(z); // no-op
+  // Of course the array can still be an input to other operations. You can even
+  // call eval on the array again, this will just be a no-op:
+  eval(z); // no-op
   // Some functions or methods on arrays implicitly evaluate them. For example
   // accessing a value in an array or printing the array implicitly evaluate it:
-  z = mx::ones({1});
+  z = ones({1});
   z.item<float>(); // implicit evaluation
-  z = mx::ones({2, 2});
+  z = ones({2, 2});
   std::cout << z << std::endl; // implicit evaluation
 }
 void automatic_differentiation() {
-  auto fn = [](mx::array x) { return mx::square(x); };
+  auto fn = [](array x) { return square(x); };
   // Computing the derivative function of a function
-  auto grad_fn = mx::grad(fn);
+  auto grad_fn = grad(fn);
   // Call grad_fn on the input to get the derivative
-  auto x = mx::array(1.5);
+  auto x = array(1.5);
   auto dfdx = grad_fn(x);
   // dfdx is 2 * x
   // Get the second derivative by composing grad with grad
-  auto d2fdx2 = mx::grad(mx::grad(fn))(x);
+  auto d2fdx2 = grad(grad(fn))(x);
   // d2fdx2 is 2
 }
@@ -1,22 +0,0 @@
cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)
add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
@@ -1,49 +0,0 @@
## Setup
Install MLX:
```bash
pip install "mlx>=0.22"
```
Build the C++ examples:
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
## Run
### Eval MLP
Run the Python script to export the eval function:
```bash
python eval_mlp.py
```
Then run the C++ program to import and run the function:
```
./build/eval_mlp
```
The Python and C++ programs should output the same result.
### Train MLP
Run the Python script to export the model initialization and training
functions:
```bash
python train_mlp.py
```
Then run the C++ program to import and run the functions:
```
./build/train_mlp
```
The Python and C++ programs should output the same results.
@@ -1,25 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include <iostream>
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
// Make the input
mx::random::seed(42);
auto example_x = mx::random::uniform({batch_size, input_dim});
// Import the function
auto forward = mx::import_function("eval_mlp.mlxfn");
// Call the imported function
auto out = forward({example_x})[0];
std::cout << out << std::endl;
return 0;
}
@@ -1,52 +0,0 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = nn.relu(l(x))
return self.layers[-1](x)
if __name__ == "__main__":
batch_size = 8
input_dim = 32
output_dim = 10
# Load the model
mx.random.seed(0) # Seed for params
model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
mx.eval(model)
# Note, the model parameters are saved in the export function
def forward(x):
return model(x)
mx.random.seed(42) # Seed for input
example_x = mx.random.uniform(shape=(batch_size, input_dim))
mx.export_function("eval_mlp.mlxfn", forward, example_x)
# Import in Python
imported_forward = mx.import_function("eval_mlp.mlxfn")
expected = forward(example_x)
(out,) = imported_forward(example_x)
assert mx.allclose(expected, out)
print(out)
@@ -1,35 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <mlx/mlx.h>
#include <iostream>
namespace mx = mlx::core;
int main() {
int batch_size = 8;
int input_dim = 32;
int output_dim = 10;
auto state = mx::import_function("init_mlp.mlxfn")({});
// Make the input
mx::random::seed(42);
auto example_X = mx::random::normal({batch_size, input_dim});
auto example_y = mx::random::randint(0, output_dim, {batch_size});
// Import the function
auto step = mx::import_function("train_mlp.mlxfn");
// Call the imported function
for (int it = 0; it < 100; ++it) {
state.insert(state.end(), {example_X, example_y});
state = step(state);
eval(state);
auto loss = state.back();
state.pop_back();
if (it % 10 == 0) {
std::cout << "Loss " << loss.item<float>() << std::endl;
}
}
return 0;
}
@@ -1,76 +0,0 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = nn.relu(l(x))
return self.layers[-1](x)
if __name__ == "__main__":
batch_size = 8
input_dim = 32
output_dim = 10
def init():
# Seed for the parameter initialization
mx.random.seed(0)
model = MLP(
num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
)
optimizer = optim.SGD(learning_rate=1e-1)
optimizer.init(model.parameters())
state = [model.parameters(), optimizer.state]
tree_structure, state = zip(*mlx.utils.tree_flatten(state))
return model, optimizer, tree_structure, state
# Export the model parameter initialization
model, optimizer, tree_structure, state = init()
mx.eval(state)
mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
def loss_fn(params, X, y):
model.update(params)
return nn.losses.cross_entropy(model(X), y, reduction="mean")
def step(*inputs):
*state, X, y = inputs
params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
optimizer.state = opt_state
loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
params = optimizer.apply_gradients(grads, params)
_, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
return *state, loss
# Make some random data
mx.random.seed(42)
example_X = mx.random.normal(shape=(batch_size, input_dim))
example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))
mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)
# Export one step of SGD
imported_step = mx.import_function("train_mlp.mlxfn")
for it in range(100):
*state, loss = imported_step(*state, example_X, example_y)
if it % 10 == 0:
print(f"Loss {loss.item():.6}")
@@ -10,6 +10,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
 # ----------------------------- Dependencies -----------------------------
+find_package(MLX CONFIG REQUIRED)
 find_package(
   Python 3.8
   COMPONENTS Interpreter Development.Module
@@ -17,15 +18,10 @@ find_package(
 execute_process(
   COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
   OUTPUT_STRIP_TRAILING_WHITESPACE
-  OUTPUT_VARIABLE nanobind_ROOT)
+  OUTPUT_VARIABLE NB_DIR)
+list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
 find_package(nanobind CONFIG REQUIRED)
-execute_process(
-  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
-  OUTPUT_STRIP_TRAILING_WHITESPACE
-  OUTPUT_VARIABLE MLX_ROOT)
-find_package(MLX CONFIG REQUIRED)
 # ----------------------------- Extensions -----------------------------
 # Add library
// Copyright © 2023-2025 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include "axpby/axpby.h" #include "axpby/axpby.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#endif
#ifdef _METAL_ #ifdef _METAL_
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#endif #endif
namespace my_ext { namespace mlx::core {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Operation Implementation // Operation Implementation
@@ -27,24 +32,24 @@ namespace my_ext {
* Follow numpy style broadcasting between x and y * Follow numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed * Inputs are upcasted to floats if needed
**/ **/
mx::array axpby( array axpby(
const mx::array& x, // Input mx::array x const array& x, // Input array x
const mx::array& y, // Input mx::array y const array& y, // Input array y
const float alpha, // Scaling factor for x const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y const float beta, // Scaling factor for y
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) { ) {
// Promote dtypes between x and y as needed // Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype()); auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y // Upcast to float32 for non-floating point inputs x and y
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32) auto out_dtype = issubdtype(promoted_dtype, float32)
? promoted_dtype ? promoted_dtype
: promote_types(promoted_dtype, mx::float32); : promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s) // Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = mx::astype(x, out_dtype, s); auto x_casted = astype(x, out_dtype, s);
auto y_casted = mx::astype(y, out_dtype, s); auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s) // Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s); auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -52,12 +57,12 @@ mx::array axpby(
// Construct the array as the output of the Axpby primitive // Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs // with the broadcasted and upcasted arrays as inputs
return mx::array( return array(
/* const mx::Shape& shape = */ out_shape, /* const std::vector<int>& shape = */ out_shape,
/* mx::Dtype dtype = */ out_dtype, /* Dtype dtype = */ out_dtype,
/* std::shared_ptr<mx::Primitive> primitive = */ /* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta), std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs); /* const std::vector<array>& inputs = */ broadcasted_inputs);
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -66,71 +71,140 @@ mx::array axpby(
template <typename T> template <typename T>
void axpby_impl( void axpby_impl(
const mx::array& x, const array& x,
const mx::array& y, const array& y,
mx::array& out, array& out,
float alpha_, float alpha_,
float beta_, float beta_) {
mx::Stream stream) { // We only allocate memory when we are ready to fill the output
// Allocate the output with `malloc_or_wait` which synchronously allocates // malloc_or_wait synchronously allocates available memory
// memory, potentially waiting if the system is under memory pressure // There may be a wait executed here if the allocation is requested
out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); // under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays // Collect input and output data pointers
auto& encoder = mx::cpu::get_command_encoder(stream); const T* x_ptr = x.data<T>();
encoder.set_input_array(x); const T* y_ptr = y.data<T>();
encoder.set_input_array(y); T* out_ptr = out.data<T>();
encoder.set_output_array(out);
// Launch the CPU kernel // Cast alpha and beta to the relevant types
encoder.dispatch([x_ptr = x.data<T>(), T alpha = static_cast<T>(alpha_);
y_ptr = y.data<T>(), T beta = static_cast<T>(beta_);
out_ptr = out.data<T>(),
size = out.size(),
shape = out.shape(),
x_strides = x.strides(),
y_strides = y.strides(),
alpha_,
beta_]() {
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the element-wise operation for each output // Do the element-wise operation for each output
for (size_t out_idx = 0; out_idx < size; out_idx++) { for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y // Map linear indices to offsets in x and y
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides); auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides); auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided // We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping // (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
} }
});
} }
void Axpby::eval_cpu( /** Fall back implementation for evaluation on CPU */
const std::vector<mx::array>& inputs, void Axpby::eval(
std::vector<mx::array>& outputs) { const std::vector<array>& inputs,
std::vector<array>& outputs) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
// Dispatch to the correct dtype // Dispatch to the correct dtype
if (out.dtype() == mx::float32) { if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_, stream()); return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::float16) { } else if (out.dtype() == float16) {
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream()); return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::bfloat16) { } else if (out.dtype() == bfloat16) {
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream()); return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == mx::complex64) { } else if (out.dtype() == complex64) {
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream()); return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else { } else {
throw std::runtime_error( throw std::runtime_error(
"Axpby is only supported for floating point types."); "Axpby is only supported for floating point types.");
} }
} }
///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef ACCELERATE_NEW_LAPACK
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
}
#endif
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation // Primitive Metal Backend Implementation
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -139,9 +213,10 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */ /** Evaluate primitive on GPU */
void Axpby::eval_gpu( void Axpby::eval_gpu(
const std::vector<mx::array>& inputs, const std::vector<array>& inputs,
std::vector<mx::array>& outputs) { std::vector<array>& outputs) {
// Prepare inputs // Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0]; auto& x = inputs[0];
auto& y = inputs[1]; auto& y = inputs[1];
auto& out = outputs[0]; auto& out = outputs[0];
@@ -150,7 +225,7 @@ void Axpby::eval_gpu(
// and each stream carries its device identifiers // and each stream carries its device identifiers
auto& s = stream(); auto& s = stream();
// We get the needed metal device using the stream // We get the needed metal device using the stream
auto& d = mx::metal::device(s.device); auto& d = metal::device(s.device);
// Prepare to specialize based on contiguity // Prepare to specialize based on contiguity
bool contiguous_kernel = bool contiguous_kernel =
@@ -160,12 +235,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization // Allocate output memory with strides based on specialization
if (contiguous_kernel) { if (contiguous_kernel) {
out.set_data( out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()), allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
} else { } else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
} }
// Resolve name of kernel (corresponds to axpby.metal) // Resolve name of kernel (corresponds to axpby.metal)
@@ -182,7 +257,7 @@ void Axpby::eval_gpu(
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel); compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to // Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal // those in the kernel declaration at axpby.metal
@@ -197,15 +272,15 @@ void Axpby::eval_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode alpha and beta // Encode alpha and beta
compute_encoder.set_bytes(alpha_, 3); compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder.set_bytes(beta_, 4); compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim if needed // Encode shape, strides and ndim if needed
if (!contiguous_kernel) { if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5); compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder.set_vector_bytes(x.strides(), 6); compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder.set_vector_bytes(y.strides(), 7); compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder.set_bytes(ndim, 8); compute_encoder->setBytes(&ndim, sizeof(int), 8);
} }
// We launch 1 thread for each input and make sure that the number of // We launch 1 thread for each input and make sure that the number of
@@ -220,15 +295,15 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among // Launch the grid with the given number of threads divided among
// the given threadgroups // the given threadgroups
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
#else // Metal is not available #else // Metal is not available
/** Fail evaluation on GPU */ /** Fail evaluation on GPU */
void Axpby::eval_gpu( void Axpby::eval_gpu(
const std::vector<mx::array>& inputs, const std::vector<array>& inputs,
std::vector<mx::array>& out) { std::vector<array>& out) {
throw std::runtime_error("Axpby has no GPU implementation."); throw std::runtime_error("Axpby has no GPU implementation.");
} }
@@ -239,9 +314,9 @@ void Axpby::eval_gpu(
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */ /** The Jacobian-vector product. */
std::vector<mx::array> Axpby::jvp( std::vector<array> Axpby::jvp(
const std::vector<mx::array>& primals, const std::vector<array>& primals,
const std::vector<mx::array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops // The jvp transform on the primitive can built with ops
@@ -253,8 +328,8 @@ std::vector<mx::array> Axpby::jvp(
// scaled by beta // scaled by beta
if (argnums.size() > 1) { if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_; auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = mx::array(scale, tangents[0].dtype()); auto scale_arr = array(scale, tangents[0].dtype());
return {mx::multiply(scale_arr, tangents[0], stream())}; return {multiply(scale_arr, tangents[0], stream())};
} }
// If, argnums = {0, 1}, we take contributions from both // If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta // which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -264,24 +339,24 @@ std::vector<mx::array> Axpby::jvp(
} }
/** The vector-Jacobian product. */ /** The vector-Jacobian product. */
std::vector<mx::array> Axpby::vjp( std::vector<array> Axpby::vjp(
const std::vector<mx::array>& primals, const std::vector<array>& primals,
const std::vector<mx::array>& cotangents, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<mx::array>&) { const std::vector<array>&) {
// Reverse mode diff // Reverse mode diff
std::vector<mx::array> vjps; std::vector<array> vjps;
for (auto arg : argnums) { for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_; auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = mx::array(scale, cotangents[0].dtype()); auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream())); vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
} }
return vjps; return vjps;
} }
/** Vectorize primitive along given axis */ /** Vectorize primitive along given axis */
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap( std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
const std::vector<mx::array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& axes) { const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation."); throw std::runtime_error("Axpby has no vmap implementation.");
} }
@@ -292,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_; return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
} }
} // namespace my_ext } // namespace mlx::core
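A usage sketch for the op defined in this file, written against the newer
(`my_ext`) side of the diff (illustrative; assumes the extension is linked
against MLX):

```cpp
#include <iostream>

#include "axpby/axpby.h"
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::array({1.0f, 2.0f, 3.0f});
  auto y = mx::ones({3});
  // out = 2.0 * x + 0.5 * y, built lazily like any other MLX op
  auto out = my_ext::axpby(x, y, 2.0f, 0.5f);
  std::cout << out << std::endl;  // array([2.5, 4.5, 6.5], dtype=float32)
  return 0;
}
```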
@@ -1,13 +1,11 @@
-// Copyright © 2023-2025 Apple Inc.
+// Copyright © 2023 Apple Inc.
 #pragma once
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
-namespace mx = mlx::core;
-namespace my_ext {
+namespace mlx::core {
 ///////////////////////////////////////////////////////////////////////////////
 // Operation
@@ -20,22 +18,22 @@ namespace my_ext {
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
  **/
-mx::array axpby(
-    const mx::array& x, // Input array x
-    const mx::array& y, // Input array y
+array axpby(
+    const array& x, // Input array x
+    const array& y, // Input array y
     const float alpha, // Scaling factor for x
     const float beta, // Scaling factor for y
-    mx::StreamOrDevice s = {} // Stream on which to schedule the operation
+    StreamOrDevice s = {} // Stream on which to schedule the operation
 );
 ///////////////////////////////////////////////////////////////////////////////
 // Primitive
 ///////////////////////////////////////////////////////////////////////////////
-class Axpby : public mx::Primitive {
+class Axpby : public Primitive {
  public:
-  explicit Axpby(mx::Stream stream, float alpha, float beta)
-      : mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
+  explicit Axpby(Stream stream, float alpha, float beta)
+      : Primitive(stream), alpha_(alpha), beta_(beta) {};
   /**
    * A primitive must know how to evaluate itself on the CPU/GPU
@@ -44,25 +42,23 @@ class Axpby : public mx::Primitive {
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
-  void eval_cpu(
-      const std::vector<mx::array>& inputs,
-      std::vector<mx::array>& outputs) override;
-  void eval_gpu(
-      const std::vector<mx::array>& inputs,
-      std::vector<mx::array>& outputs) override;
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
   /** The Jacobian-vector product. */
-  std::vector<mx::array> jvp(
-      const std::vector<mx::array>& primals,
-      const std::vector<mx::array>& tangents,
+  std::vector<array> jvp(
+      const std::vector<array>& primals,
+      const std::vector<array>& tangents,
       const std::vector<int>& argnums) override;
   /** The vector-Jacobian product. */
-  std::vector<mx::array> vjp(
-      const std::vector<mx::array>& primals,
-      const std::vector<mx::array>& cotangents,
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
       const std::vector<int>& argnums,
-      const std::vector<mx::array>& outputs) override;
+      const std::vector<array>& outputs) override;
   /**
    * The primitive must know how to vectorize itself across
@@ -70,8 +66,8 @@ class Axpby : public mx::Primitive {
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
-  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
-      const std::vector<mx::array>& inputs,
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
   /** Print the primitive. */
@@ -80,11 +76,14 @@ class Axpby : public mx::Primitive {
   }
   /** Equivalence check **/
-  bool is_equivalent(const mx::Primitive& other) const override;
+  bool is_equivalent(const Primitive& other) const override;
  private:
   float alpha_;
   float beta_;
+  /** Fall back implementation for evaluation on CPU */
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
-} // namespace my_ext
+} // namespace mlx::core
@@ -1,7 +1,8 @@
-// Copyright © 2023-2025 Apple Inc.
+// Copyright © 2023 Apple Inc.
 #include <metal_stdlib>
+#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 template <typename T>
@@ -12,8 +13,8 @@ template <typename T>
     constant const float& alpha [[buffer(3)]],
     constant const float& beta [[buffer(4)]],
     constant const int* shape [[buffer(5)]],
-    constant const int64_t* x_strides [[buffer(6)]],
-    constant const int64_t* y_strides [[buffer(7)]],
+    constant const size_t* x_strides [[buffer(6)]],
+    constant const size_t* y_strides [[buffer(7)]],
     constant const int& ndim [[buffer(8)]],
     uint index [[thread_position_in_grid]]) {
   auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -34,14 +35,29 @@ template <typename T>
       static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
 }
-// clang-format off
-#define instantiate_axpby(type_name, type)                              \
-  instantiate_kernel("axpby_general_" #type_name, axpby_general, type)  \
-  instantiate_kernel(                                                   \
-      "axpby_contiguous_" #type_name, axpby_contiguous, type)
+#define instantiate_axpby(type_name, type)                               \
+  template [[host_name("axpby_general_" #type_name)]] [[kernel]] void    \
+  axpby_general<type>(                                                   \
+      device const type* x [[buffer(0)]],                                \
+      device const type* y [[buffer(1)]],                                \
+      device type* out [[buffer(2)]],                                    \
+      constant const float& alpha [[buffer(3)]],                         \
+      constant const float& beta [[buffer(4)]],                          \
+      constant const int* shape [[buffer(5)]],                           \
+      constant const size_t* x_strides [[buffer(6)]],                    \
+      constant const size_t* y_strides [[buffer(7)]],                    \
+      constant const int& ndim [[buffer(8)]],                            \
+      uint index [[thread_position_in_grid]]);                           \
+  template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
+  axpby_contiguous<type>(                                                \
+      device const type* x [[buffer(0)]],                                \
+      device const type* y [[buffer(1)]],                                \
+      device type* out [[buffer(2)]],                                    \
+      constant const float& alpha [[buffer(3)]],                         \
+      constant const float& beta [[buffer(4)]],                          \
+      uint index [[thread_position_in_grid]]);
 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);
-// clang-format on
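Either way the kernels are instantiated, the host code selects one by name at
dispatch time. A sketch of that resolution (variable names are illustrative;
`type_to_name` is MLX's dtype-to-string helper, and `"mlx_ext"` is assumed to
be the extension's metallib name):

```cpp
// Composes e.g. "axpby_general_float32" or "axpby_contiguous_float16",
// matching the [[host_name(...)]] attributes instantiated above.
std::string kname = contiguous_kernel ? "axpby_contiguous_" : "axpby_general_";
kname += mx::type_to_name(out);
auto kernel = d.get_kernel(kname, "mlx_ext");
```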
@@ -8,12 +8,14 @@
 namespace nb = nanobind;
 using namespace nb::literals;
+using namespace mlx::core;
 NB_MODULE(_ext, m) {
   m.doc() = "Sample extension for MLX";
   m.def(
       "axpby",
-      &my_ext::axpby,
+      &axpby,
       "x"_a,
       "y"_a,
       "alpha"_a,
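A usage sketch for the binding above (assuming the package builds and installs
under the name used by the MLX extensions example, `mlx_sample_extensions`):

```python
import mlx.core as mx
from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0)  # c = 4.0 * a + 2.0 * b
print(c)  # every entry is 6.0
```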
@@ -1,8 +1,8 @@
 [build-system]
 requires = [
     "setuptools>=42",
-    "cmake>=3.25",
+    "cmake>=3.24",
     "mlx>=0.18.0",
-    "nanobind==2.4.0",
+    "nanobind==2.2.0",
 ]
 build-backend = "setuptools.build_meta"
@@ -1,4 +1,4 @@
 setuptools>=42
-cmake>=3.25
+cmake>=3.24
-mlx>=0.21.0
+mlx>=0.18.1
 nanobind==2.2.0
@@ -28,19 +28,10 @@ endif()
 if (@MLX_BUILD_METAL@)
   set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
   set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
-  set(MLX_INCLUDE_DIRS
-      "${MLX_INCLUDE_DIRS};"
+  set_and_check(MLX_INCLUDE_DIRS
+      ${MLX_INCLUDE_DIRS}
       @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
   )
-  if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
-    set(MLX_INCLUDE_DIRS
-        "${MLX_INCLUDE_DIRS};"
-        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
-  else()
-    set(MLX_INCLUDE_DIRS
-        "${MLX_INCLUDE_DIRS};"
-        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
-  endif()
 endif()
 set_target_properties(mlx PROPERTIES
@@ -49,4 +40,4 @@
 )
 include(FindPackageHandleStandardArgs)
 find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
@@ -5,7 +5,6 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -19,31 +18,21 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
-# Define MLX_VERSION only in the version.cpp file.
-add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
-target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
-target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
-if(MSVC)
-  # Disable some MSVC warnings to speed up compilation.
-  target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
-endif()
-if(WIN32)
-  # Export symbols by default to behave like macOS/linux.
-  set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
-endif()
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
 if(MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
 else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
 endif()
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
+if(MLX_BUILD_ACCELERATE)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
+elseif(MLX_BUILD_CPU)
+  target_sources(
+    mlx
+    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
+endif()
 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
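The static `mlx_version` helper library on the newer side exists so that only a
single translation unit is compiled with the `MLX_VERSION` define, keeping
version bumps from forcing a full rebuild. A hedged sketch of what that
translation unit might contain (the function name is an assumption, not taken
from this diff):

```cpp
// version.cpp -- the only file compiled with -DMLX_VERSION="x.y.z"
#include <string>

namespace mlx::core {

// Returning the macro from one .cpp keeps MLX_VERSION out of public headers.
std::string version() {
  return MLX_VERSION;
}

} // namespace mlx::core
```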
@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
 }
 void free(Buffer buffer) {
-  allocator().free(buffer);
+  return allocator().free(buffer);
 }
 Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@@ -1,6 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <functional>
-#include <unordered_map>
 #include "mlx/array.h"
 #include "mlx/ops.h"
@@ -10,14 +9,28 @@
 namespace mlx::core {
+namespace {
+/** Return true if we are currently performing a function transformation in
+ * order to keep the graph when evaluating tracer arrays. */
+bool in_tracing() {
+  return detail::InTracing::in_tracing();
+}
+bool retain_graph() {
+  return detail::RetainGraph::retain_graph();
+}
+} // namespace
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
-    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
   init(&cval);
 }
 array::array(
-    Shape shape,
+    std::vector<int> shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
@@ -25,21 +38,10 @@ array::array(
           std::move(shape),
           dtype,
           std::move(primitive),
-          std::move(inputs))) {
-  if (has_primitive() && this->primitive().stream().device == Device::gpu) {
-    for (auto& in : this->inputs()) {
-      if (in.dtype() == float64) {
-        throw std::invalid_argument("float64 is not supported on the GPU");
-      }
-    }
-    if (this->dtype() == float64) {
-      throw std::invalid_argument("float64 is not supported on the GPU");
-    }
-  }
-}
+          std::move(inputs))) {}
 std::vector<array> array::make_arrays(
-    std::vector<Shape> shapes,
+    std::vector<std::vector<int>> shapes,
     const std::vector<Dtype>& dtypes,
     const std::shared_ptr<Primitive>& primitive,
     const std::vector<array>& inputs) {
@@ -56,59 +58,47 @@ std::vector<array> array::make_arrays(
   return outputs;
 }
-array array::unsafe_weak_copy(const array& other) {
-  auto cpy = array(other.shape(), other.dtype(), nullptr, {});
-  cpy.set_data(
-      other.buffer(),
-      other.data_size(),
-      other.strides(),
-      other.flags(),
-      [](auto) {});
-  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
-  return cpy;
-}
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
-          Shape{static_cast<ShapeElem>(data.size())},
+          std::vector<int>{static_cast<int>(data.size())},
           float32)) {
   init(data.begin());
 }
 array::array(std::initializer_list<int> data, Dtype dtype)
     : array_desc_(std::make_shared<ArrayDesc>(
-          Shape{static_cast<ShapeElem>(data.size())},
+          std::vector<int>{static_cast<int>(data.size())},
           dtype)) {
   init(data.begin());
 }
 /* Build an array from a shared buffer */
-array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
+array::array(
+    allocator::Buffer data,
+    std::vector<int> shape,
+    Dtype dtype,
+    deleter_t deleter)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   set_data(data, deleter);
 }
 void array::detach() {
-  array_desc_->primitive = nullptr;
-  for (auto& s : array_desc_->siblings) {
-    s.array_desc_->primitive = nullptr;
-  }
   for (auto& s : array_desc_->siblings) {
     s.array_desc_->inputs.clear();
     s.array_desc_->siblings.clear();
     s.array_desc_->position = 0;
+    s.array_desc_->primitive = nullptr;
   }
   array_desc_->inputs.clear();
   array_desc_->siblings.clear();
   array_desc_->position = 0;
+  array_desc_->primitive = nullptr;
 }
 bool array::is_available() const {
   if (status() == Status::available) {
     return true;
-  } else if (
-      status() == Status::evaluated &&
-      (!event().valid() || event().is_signaled())) {
+  } else if (status() == Status::evaluated && event().is_signaled()) {
     set_status(Status::available);
     return true;
   }
@@ -117,10 +107,7 @@ bool array::is_available() const {
 void array::wait() {
   if (!is_available()) {
-    if (event().valid()) {
-      event().wait();
-      detach_event();
-    }
+    event().wait();
     set_status(Status::available);
   }
 }
@@ -135,11 +122,10 @@ void array::eval() {
 }
 bool array::is_tracer() const {
-  return (array_desc_->is_tracer && detail::in_tracing()) ||
-      detail::retain_graph();
+  return array_desc_->is_tracer && in_tracing() || retain_graph();
 }
-void array::set_data(allocator::Buffer buffer, Deleter d) {
+void array::set_data(allocator::Buffer buffer, deleter_t d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
   array_desc_->data_ptr = buffer.raw_ptr();
   array_desc_->data_size = size();
@@ -152,9 +138,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
 void array::set_data(
     allocator::Buffer buffer,
     size_t data_size,
-    Strides strides,
+    std::vector<size_t> strides,
     Flags flags,
-    Deleter d) {
+    deleter_t d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
   array_desc_->data_ptr = buffer.raw_ptr();
   array_desc_->data_size = data_size;
@@ -164,7 +150,7 @@ void array::set_data(
 void array::copy_shared_buffer(
     const array& other,
-    const Strides& strides,
+    const std::vector<size_t>& strides,
     Flags flags,
     size_t data_size,
     size_t offset /* = 0 */) {
@@ -181,13 +167,34 @@ void array::copy_shared_buffer(const array& other) {
   copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
 }
+void array::move_shared_buffer(
+    array other,
+    const std::vector<size_t>& strides,
+    Flags flags,
+    size_t data_size,
+    size_t offset /* = 0 */) {
+  array_desc_->data = std::move(other.array_desc_->data);
+  array_desc_->strides = strides;
+  array_desc_->flags = flags;
+  array_desc_->data_size = data_size;
+  auto char_offset = sizeof(char) * itemsize() * offset;
+  auto data_ptr = other.array_desc_->data_ptr;
+  other.array_desc_->data_ptr = nullptr;
+  array_desc_->data_ptr =
+      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
+}
+void array::move_shared_buffer(array other) {
+  move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
+}
 array::~array() {
   if (array_desc_ == nullptr) {
     return;
   }
-  // Detached/detaching
-  if (array_desc_->primitive == nullptr) {
+  // Ignore arrays that might be detached during eval
+  if (status() == array::Status::scheduled) {
     return;
   }
@@ -207,8 +214,6 @@ array::~array() {
   if (do_detach) {
     for (auto& s : siblings()) {
       for (auto& ss : s.siblings()) {
-        // Set to null here to avoid descending into array destructor
-        // for siblings
         ss.array_desc_ = nullptr;
       }
       s.array_desc_->siblings.clear();
@@ -229,13 +234,13 @@ void array::ArrayDesc::init() {
   }
 }
-array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
+array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
     : shape(std::move(shape)), dtype(dtype), status(Status::available) {
   init();
 }
 array::ArrayDesc::ArrayDesc(
-    Shape shape,
+    std::vector<int> shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
@@ -273,19 +278,7 @@ array::ArrayDesc::~ArrayDesc() {
     }
     ad.inputs.clear();
     for (auto& [_, a] : input_map) {
-      bool is_deletable =
-          (a.array_desc_.use_count() <= a.siblings().size() + 1);
-      // An array with siblings is deletable only if all of its siblings
-      // are deletable
-      for (auto& s : a.siblings()) {
-        if (!is_deletable) {
+      if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
break;
}
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= a.siblings().size() + is_input;
}
if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_)); for_deletion.push_back(std::move(a.array_desc_));
} }
} }
@@ -299,14 +292,6 @@ array::ArrayDesc::~ArrayDesc() {
auto top = std::move(for_deletion.back()); auto top = std::move(for_deletion.back());
for_deletion.pop_back(); for_deletion.pop_back();
append_deletable_inputs(*top); append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
} }
} }
@@ -318,7 +303,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
} }
array::ArrayIterator::reference array::ArrayIterator::operator*() const { array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = Shape(arr.ndim(), 0); auto start = std::vector<int>(arr.ndim(), 0);
auto end = arr.shape(); auto end = arr.shape();
auto shape = arr.shape(); auto shape = arr.shape();
shape.erase(shape.begin()); shape.erase(shape.begin());

@@ -15,11 +15,7 @@ namespace mlx::core {
 // Forward declaration
 class Primitive;
-using Deleter = std::function<void(allocator::Buffer)>;
-using ShapeElem = int32_t;
-using Shape = std::vector<ShapeElem>;
-using Strides = std::vector<int64_t>;
+using deleter_t = std::function<void(allocator::Buffer)>;
 class array {
   /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -35,33 +31,33 @@ class array {
   explicit array(const std::complex<float>& val, Dtype dtype = complex64);
   template <typename It>
-  explicit array(
+  array(
       It data,
-      Shape shape,
+      std::vector<int> shape,
       Dtype dtype =
           TypeToDtype<typename std::iterator_traits<It>::value_type>());
   template <typename T>
-  explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
+  array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
   /* Special case so empty lists default to float32. */
-  explicit array(std::initializer_list<float> data);
+  array(std::initializer_list<float> data);
   /* Special case so array({}, type) is an empty array. */
-  explicit array(std::initializer_list<int> data, Dtype dtype);
+  array(std::initializer_list<int> data, Dtype dtype);
   template <typename T>
-  explicit array(
+  array(
       std::initializer_list<T> data,
-      Shape shape,
+      std::vector<int> shape,
       Dtype dtype = TypeToDtype<T>());
   /* Build an array from a buffer */
-  explicit array(
+  array(
       allocator::Buffer data,
-      Shape shape,
+      std::vector<int> shape,
       Dtype dtype,
-      Deleter deleter = allocator::free);
+      deleter_t deleter = allocator::free);
   /** Assignment to rvalue does not compile. */
   array& operator=(const array& other) && = delete;
@@ -100,7 +96,7 @@ class array {
   }
   /** The shape of the array as a vector of integers. */
-  const Shape& shape() const {
+  const std::vector<int>& shape() const {
     return array_desc_->shape;
   }
@@ -109,12 +105,12 @@ class array {
    *
    * This function supports negative indexing and provides
    * bounds checking. */
-  auto shape(int dim) const {
+  int shape(int dim) const {
     return shape().at(dim < 0 ? dim + ndim() : dim);
   }
   /** The strides of the array. */
-  const Strides& strides() const {
+  const std::vector<size_t>& strides() const {
     return array_desc_->strides;
   }
@@ -123,7 +119,7 @@ class array {
    *
    * This function supports negative indexing and provides
    * bounds checking. */
-  auto strides(int dim) const {
+  size_t strides(int dim) const {
     return strides().at(dim < 0 ? dim + ndim() : dim);
   }
@@ -188,24 +184,17 @@ class array {
    */
   array(
-      Shape shape,
+      std::vector<int> shape,
       Dtype dtype,
       std::shared_ptr<Primitive> primitive,
       std::vector<array> inputs);
   static std::vector<array> make_arrays(
-      std::vector<Shape> shapes,
+      std::vector<std::vector<int>> shapes,
       const std::vector<Dtype>& dtypes,
      const std::shared_ptr<Primitive>& primitive,
       const std::vector<array>& inputs);
-  /**
-   * Get a new array that refers to the same data as the input but with a
-   * non-owning pointer to it. Note the array is detached from the graph and has
-   * no inputs, siblings or primitive.
-   */
-  static array unsafe_weak_copy(const array& other);
   /** A unique identifier for an array. */
   std::uintptr_t id() const {
     return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -218,8 +207,8 @@ class array {
   struct Data {
     allocator::Buffer buffer;
-    Deleter d;
-    Data(allocator::Buffer buffer, Deleter d = allocator::free)
+    deleter_t d;
+    Data(allocator::Buffer buffer, deleter_t d = allocator::free)
         : buffer(buffer), d(d) {}
     // Not copyable
     Data(const Data& d) = delete;
@@ -360,6 +349,11 @@ class array {
     // For example, the status of `x` in `auto x = a + b`.
     unscheduled,
+    // The output of a computation which has been scheduled but `eval_*` has
+    // not yet been called on the array's primitive. A possible
+    // status of `x` in `auto x = a + b; eval(x);`
+    scheduled,
     // The array's `eval_*` function has been run, but the computation is not
     // necessarily complete. The array will have memory allocated and if it is
     // not a tracer then it will be detached from the graph.
@@ -396,10 +390,6 @@ class array {
     array_desc_->event = std::move(e);
   }
-  void detach_event() const {
-    array_desc_->event = Event{};
-  }
   // Mark the array as a tracer array (true) or not.
   void set_tracer(bool is_tracer) {
     array_desc_->is_tracer = is_tracer;
@@ -407,24 +397,33 @@ class array {
   // Check if the array is a tracer array
   bool is_tracer() const;
-  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
+  void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
   void set_data(
       allocator::Buffer buffer,
       size_t data_size,
-      Strides strides,
+      std::vector<size_t> strides,
       Flags flags,
-      Deleter d = allocator::free);
+      deleter_t d = allocator::free);
   void copy_shared_buffer(
       const array& other,
-      const Strides& strides,
+      const std::vector<size_t>& strides,
       Flags flags,
       size_t data_size,
       size_t offset = 0);
   void copy_shared_buffer(const array& other);
+  void move_shared_buffer(
+      array other,
+      const std::vector<size_t>& strides,
+      Flags flags,
+      size_t data_size,
+      size_t offset = 0);
+  void move_shared_buffer(array other);
   void overwrite_descriptor(const array& other) {
     array_desc_ = other.array_desc_;
   }
@@ -437,8 +436,8 @@ class array {
   void init(const It src);
   struct ArrayDesc {
-    Shape shape;
-    Strides strides;
+    std::vector<int> shape;
+    std::vector<size_t> strides;
     size_t size;
     Dtype dtype;
     std::shared_ptr<Primitive> primitive;
@@ -472,10 +471,10 @@ class array {
     // The array's position in the output list
     uint32_t position{0};
-    explicit ArrayDesc(Shape shape, Dtype dtype);
+    explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
     explicit ArrayDesc(
-        Shape shape,
+        std::vector<int> shape,
         Dtype dtype,
         std::shared_ptr<Primitive> primitive,
         std::vector<array> inputs);
@@ -496,14 +495,14 @@ class array {
 template <typename T>
 array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
-    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
   init(&val);
 }
 template <typename It>
 array::array(
     It data,
-    Shape shape,
+    std::vector<int> shape,
     Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
     array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   init(data);
@@ -514,7 +513,7 @@ array::array(
     std::initializer_list<T> data,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(
-          Shape{static_cast<ShapeElem>(data.size())},
+          std::vector<int>{static_cast<int>(data.size())},
          dtype)) {
   init(data.begin());
 }
@@ -522,7 +521,7 @@ array::array(
 template <typename T>
 array::array(
     std::initializer_list<T> data,
-    Shape shape,
+    std::vector<int> shape,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   if (data.size() != size()) {
@@ -591,9 +590,6 @@ void array::init(It src) {
     case float32:
       std::copy(src, src + size(), data<float>());
       break;
-    case float64:
-      std::copy(src, src + size(), data<double>());
-      break;
     case bfloat16:
      std::copy(src, src + size(), data<bfloat16_t>());
      break;
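
To make the type change in this diff concrete, here is a small standalone sketch (hypothetical code, not part of the change): the `-` side of the compare names the shape and stride types through aliases, while the `+` side spells out the raw vector types at every call site.

#include <cstdint>
#include <vector>

// Aliases as declared on the `-` side of the header diff above.
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;

int main() {
  // A 2 x 3 row-major array: strides are counted in elements, not bytes.
  Shape shape{2, 3};
  Strides strides{3, 1};
  // The `+` side writes the same thing with the raw types:
  std::vector<int> old_shape{2, 3};
  std::vector<size_t> old_strides{3, 1};
  return shape.size() == old_shape.size() &&
          strides.size() == old_strides.size()
      ? 0
      : 1;
}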

@@ -0,0 +1,8 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)

@@ -0,0 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
// TODO: Add accelerate based optimizations for CPU conv
}
} // namespace mlx::core

@@ -0,0 +1,253 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
}
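// Worked example of the checks above (illustrative): a row-major 3 x 4
// matrix has strides {4, 1}, so stx == shape(-1) == 4 and sty == 1 and the
// data can go to BLAS as-is with lda = 4. Its transposed view (4 x 3 with
// strides {1, 4}) hits the second branch, since stx == 1 and
// sty == shape(-2) == 4, and is handled with CblasTrans instead of a copy.
// Any other layout (e.g. a sliced, non-contiguous view) falls through to the
// general copy.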
inline void matmul_cblas_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_cblas_general(a_pre, b_pre, out);
}
inline void matmul_bnns_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
// TODO: Update to utilize BNNS broadcasting
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
/* bool a_is_weights = */ false,
/* bool b_is_weights = */ false,
/* BNNSNDArrayDescriptor iA_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, lda, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor iB_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, ldb, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor o_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{N, M, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, N, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
};
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
}
BNNSFilterDestroy(bnns_filter);
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_bnns_general(a_pre, b_pre, out);
}
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32) {
return matmul_cblas(inputs[0], inputs[1], out);
}
return matmul_bnns(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
if (out.dtype() == float32) {
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
}
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
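
For reference, a minimal standalone example of the row-major cblas_sgemm convention used above (an illustration, not MLX code; the transA/transB flags and leading dimensions are exactly what check_transpose supplies):

#include <Accelerate/Accelerate.h>
#include <cstdio>

int main() {
  const int M = 2, N = 2, K = 3;
  float A[M * K] = {1, 2, 3, 4, 5, 6}; // 2 x 3, row-major
  float B[K * N] = {1, 0, 0, 1, 1, 1}; // 3 x 2, row-major
  float C[M * N] = {0, 0, 0, 0};       // 2 x 2 output
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // transA: A is already row-major
      CblasNoTrans, // transB: B is already row-major
      M,
      N,
      K,
      1.0f, // alpha
      A,
      K, // lda = K for an untransposed row-major A
      B,
      N, // ldb = N for an untransposed row-major B
      0.0f, // beta
      C,
      N); // ldc = number of output columns
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 4 5 / 10 11
  return 0;
}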

@@ -0,0 +1,601 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include <Accelerate/Accelerate.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
// Use the default implementation for the following primitives
DEFAULT(Arange)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
} else {
eval(inputs, out);
}
}
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
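// The donation pattern above recurs in several kernels in this file: when an
// input is row-contiguous, has no other consumers (is_donatable), and its
// itemsize matches the output's, the output adopts the input's buffer via
// copy_shared_buffer instead of allocating fresh memory, so the vForce call
// effectively runs in place.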
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().contiguous) {
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
}
eval(inputs, out);
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x / y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x / y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
if (in.data_size() == 1 && out.dtype() == float32) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
} else {
eval(inputs, out);
}
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
switch (base_) {
case Base::e:
vvlogf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::two:
vvlog2f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::ten:
vvlog10f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
}
} else {
eval(inputs, out);
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x * y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
eval(inputs, out);
}
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int stride = in.shape(axis_);
int count = in.size() / stride;
const float* input = in.data<float>();
float* output = out.data<float>();
float s = 1.0;
if (!reverse_) {
for (int i = 0; i < count; i++) {
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
input += stride;
output += stride;
}
} else {
for (int i = 0; i < count; i++) {
input += stride - 1;
output += stride - 1;
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
input++;
output++;
}
}
} else {
eval(inputs, out);
}
}
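// A note on the vDSP_vrsum calls above (an illustrative reading of the vDSP
// docs, not part of the original source): vrsum writes C[0] = 0 and
// C[n] = s * (A[1] + ... + A[n]). Passing input - 1 with s = 1 therefore
// yields the exclusive prefix sum output[n] = input[0] + ... + input[n - 1],
// matching the !inclusive_ condition; the reversed branch runs the same
// recurrence from the end of each row using negative strides.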
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
eval(inputs, out);
}
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
vvsqrtf(out.data<float>(), in.data<float>(), &size);
}
} else {
eval(inputs, out);
}
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x - y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
float minus_1 = -1;
vDSP_vsmsa(
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
float val = -(*s);
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x - y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
int val = -(*s);
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
},
UseDefaultBinaryOp());
} else {
eval(inputs, out);
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core
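
As a standalone reference for the vDSP fast paths above (illustration only, not MLX code), the element-wise vector add used by the Add kernel looks like this:

#include <Accelerate/Accelerate.h>
#include <cstdio>

int main() {
  float a[4] = {1, 2, 3, 4};
  float b[4] = {10, 20, 30, 40};
  float o[4];
  // Strides of 1 (every element), n = 4 elements.
  vDSP_vadd(a, 1, b, 1, o, 1, 4);
  std::printf("%g %g %g %g\n", o[0], o[1], o[2], o[3]); // 11 22 33 44
  return 0;
}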

@@ -0,0 +1,117 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <simd/vector.h>
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void _qmm_t_4_64(
float* result,
const float* x,
const uint32_t* w,
const float* scales,
const float* biases,
int M,
int N,
int K,
int B,
bool batched_w) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
}
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
bool condition =
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
scales.flags().row_contiguous && biases.flags().row_contiguous &&
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
B,
batched_w);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core
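
A standalone sketch of the 4-bit unpacking at the core of _qmm_t_4_64 above (illustration only; the packed value and quantization parameters here are made up): each uint32 packs eight 4-bit weights, low bits first, which are recovered by masking and shifting and then dequantized as w * scale + bias.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int bits = 4;
  constexpr int bitmask = (1 << bits) - 1; // 0xF
  constexpr int pack_factor = 32 / bits;   // 8 weights per uint32
  uint32_t packed = 0x76543210;            // nibbles 0..7, low bits first
  float scale = 0.5f, bias = -1.0f;        // hypothetical group parameters
  for (int p = 0; p < pack_factor; ++p) {
    uint32_t q = packed & bitmask; // extract the low nibble
    packed >>= bits;
    std::printf("w[%d] = %g\n", p, q * scale + bias); // -1, -0.5, 0, ...
  }
  return 0;
}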

@@ -0,0 +1,139 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
}
}
};
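// Descriptive note (not in the original source): the operator() above reduces
// `size` consecutive slices of length `stride` into one accumulator row.
// Chunks of width N are combined with the vector op while the remaining
// (stride % N) tail is handled scalar-by-scalar, so reductions over a
// non-innermost axis still walk memory sequentially.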
} // namespace
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32) {
if (reduce_type_ == Reduce::Sum) {
reduction_op<float, float>(
in,
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
(*accum) += acc;
},
[](auto* accum, auto x) { *accum += x; });
return;
} else if (reduce_type_ == Reduce::Max) {
reduction_op<float, float>(
in,
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
(*accum) = (*accum < max) ? max : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
return;
} else if (reduce_type_ == Reduce::Min) {
reduction_op<float, float>(
in,
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);
(*accum) = (*accum > min) ? min : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
return;
}
}
// TODO: Add integer addition and min/max using the templates above and
// simd_int16 and friends.
eval(inputs, out);
}
} // namespace mlx::core

@@ -0,0 +1,393 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
ipart = simd::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
// Avoid suppressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}
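// Worked example of the decomposition above (illustrative, not part of the
// original source): for x = 1, y = x * log2(e) ~= 1.4427, so
// ipart = floor(y + 0.5) = 1 and fpart ~= 0.4427. The polynomial gives
// 2**0.4427 ~= 1.3591, and the bit trick builds 2**1 = 2.0 by writing
// (1 + 127) << 23 into the float exponent field, so the result is
// 2.0 * 1.3591 ~= 2.718 ~= e.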
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
int16x8_t epart = vcvtq_s16_f16(ipart);
epart = vaddq_s16(epart, vdupq_n_s16(15));
epart = vshlq_n_s16(epart, 10);
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}
/**
* Implementation of folding maximum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_max(float16x8_t x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
/**
* Implementation of folding sum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_add(float16x8_t x) {
float16x4_t y;
float16x4_t zero = vdup_n_f16(0);
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
y = vpadd_f16(y, zero);
y = vpadd_f16(y, zero);
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
VT vnormalizer = ops.init(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:
eval(inputs, out);
break;
case complex64:
eval(inputs, out);
break;
}
}
} // namespace mlx::core
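
For orientation, a minimal scalar reference of the three-pass, numerically stable softmax that the vectorized kernel above implements (max, exponentiate-and-sum, normalize); a standalone illustration, not MLX API:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> softmax_reference(const std::vector<float>& x) {
  // Pass 1: find the maximum so the exponentials cannot overflow.
  float maximum = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float normalizer = 0.0f;
  // Pass 2: exponentiate the shifted inputs and accumulate the sum.
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - maximum);
    normalizer += out[i];
  }
  // Pass 3: normalize.
  for (auto& v : out) {
    v /= normalizer;
  }
  return out;
}

int main() {
  auto p = softmax_reference({1.0f, 2.0f, 3.0f});
  std::printf("%f %f %f\n", p[0], p[1], p[2]); // ~0.0900 0.2447 0.6652
  return 0;
}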

@@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <Accelerate/Accelerate.h>
#include "mlx/dtype.h"
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
}
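// Example of the bit composition above (illustrative): float32 has
// size_of == 4, so size_bits == 32 and kind 'f' yields
// BNNSDataType(BNNSDataTypeFloatBit | 32), i.e. BNNSDataTypeFloat32.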
} // namespace mlx::core

@@ -1,8 +1,62 @@
+if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+  set(COMPILER ${CMAKE_C_COMPILER})
+  set(CLANG TRUE)
+else()
+  set(COMPILER ${CMAKE_CXX_COMPILER})
+endif()
+add_custom_command(
+  OUTPUT compiled_preamble.cpp
+  COMMAND
+    /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
+    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
+    ${PROJECT_SOURCE_DIR} ${CLANG}
+  DEPENDS make_compiled_preamble.sh
+          compiled_preamble.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
+          ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
+          ops.h)
+add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
+add_dependencies(mlx cpu_compiled_preamble)
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
+          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
+          ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
+if(IOS)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
+else()
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
+endif()

@@ -0,0 +1,74 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size) {
auto ptr = out.data<T>();
auto step_size = next - start;
for (int i = 0; i < size; ++i) {
ptr[i] = start;
start += step_size;
}
}
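// Descriptive note (not in the original source): the public arange below
// calls this helper as arange<T>(start, start + step, ...), so
// step_size = next - start recovers the step after both endpoints have been
// converted to the output type T. For example, for uint8 output with
// start = 0.0 and step = 1.5, T(0.0) = 0 and T(1.5) = 1 give step_size = 1.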
} // namespace
void arange(
const std::vector<array>& inputs,
array& out,
double start,
double step) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
break;
case uint8:
arange<uint8_t>(start, start + step, out, out.size());
break;
case uint16:
arange<uint16_t>(start, start + step, out, out.size());
break;
case uint32:
arange<uint32_t>(start, start + step, out, out.size());
break;
case uint64:
arange<uint64_t>(start, start + step, out, out.size());
break;
case int8:
arange<int8_t>(start, start + step, out, out.size());
break;
case int16:
arange<int16_t>(start, start + step, out, out.size());
break;
case int32:
arange<int32_t>(start, start + step, out, out.size());
break;
case int64:
arange<int64_t>(start, start + step, out, out.size());
break;
case float16:
arange<float16_t>(start, start + step, out, out.size());
break;
case float32:
arange<float>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;
case complex64:
arange<complex64_t>(start, start + step, out, out.size());
break;
}
}
} // namespace mlx::core

@@ -0,0 +1,112 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/primitives.h"
#include "utils.h"
namespace mlx::core {
namespace {
template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
std::vector<size_t> strides = in.strides();
std::vector<int> shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto in_ptr = in.data<InT>() + loc;
uint32_t ind_v = 0;
InT v = (*in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
op(j, (*in_ptr), &ind_v, &v);
}
out.data<uint32_t>()[i] = ind_v;
}
}
template <typename InT>
void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
int axis) {
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x < (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
break;
}
case ArgReduce::ArgMax: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
if (x > (*y)) {
(*y) = x;
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
break;
}
}
}
} // namespace
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
}
} // namespace mlx::core
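For every output element the loop asks elem_to_loc to turn the flat index into a storage offset, using the shape and strides with the reduced axis erased. A minimal sketch of that mapping, assuming the innermost dimension varies fastest (this mirrors, but is not, MLX's helper):

#include <cstdio>
#include <vector>

// Map a flat (row-major) element index to a storage offset for an
// array with the given shape and strides.
size_t elem_to_loc_sketch(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored column-major has strides {1, 2}; logical
  // element 1 (row 0, column 1) lives at storage offset 2.
  std::printf("%zu\n", elem_to_loc_sketch(1, {2, 3}, {1, 2}));
}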


@@ -0,0 +1,331 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, U, Op> opsv(op);
DefaultVectorScalar<T, U, Op> opvs(op);
DefaultVectorVector<T, U, Op> opvv(op);
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
}
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
comparison_op<bool, bool>(a, b, out, op);
break;
case uint8:
comparison_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
comparison_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
comparison_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
comparison_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
comparison_op<int8_t, bool>(a, b, out, op);
break;
case int16:
comparison_op<int16_t, bool>(a, b, out, op);
break;
case int32:
comparison_op<int32_t, bool>(a, b, out, op);
break;
case int64:
comparison_op<int64_t, bool>(a, b, out, op);
break;
case float16:
comparison_op<float16_t, bool>(a, b, out, op);
break;
case float32:
comparison_op<float, bool>(a, b, out, op);
break;
case bfloat16:
comparison_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
comparison_op<complex64_t, bool>(a, b, out, op);
break;
}
}
} // namespace
void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
}
void DivMod::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
}
void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
}
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
} else {
comparison_op(inputs[0], inputs[1], out, detail::Equal());
}
}
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
}
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
}
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
}
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
}
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
" non floating point type.");
}
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
}
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
}
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
}
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
}
} // namespace mlx::core
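Note that the float path of DivMod pairs std::trunc(x / y) with std::fmod(x, y), so the identity q * y + r == x holds and the remainder keeps the sign of x. A quick standalone check of those semantics:

#include <cmath>
#include <cstdio>

int main() {
  double x = -7.0, y = 2.0;
  double q = std::trunc(x / y); // -3, truncated toward zero
  double r = std::fmod(x, y);   // -1, same sign as x
  std::printf("q=%g r=%g check=%g\n", q, r, q * y + r); // check == x
}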


@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
+#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
@@ -8,6 +9,8 @@
namespace mlx::core {
+namespace {
enum class BinaryOpType {
ScalarScalar,
ScalarVector,
@@ -16,7 +19,7 @@ enum class BinaryOpType {
General,
};
-inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
+BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = BinaryOpType::ScalarScalar;
@@ -25,8 +28,8 @@ inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = BinaryOpType::VectorScalar;
} else if (
-(a.flags().row_contiguous && b.flags().row_contiguous) ||
-(a.flags().col_contiguous && b.flags().col_contiguous)) {
+a.flags().row_contiguous && b.flags().row_contiguous ||
+a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = BinaryOpType::VectorVector;
} else {
bopt = BinaryOpType::General;
@@ -34,11 +37,12 @@ inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
return bopt;
}
-inline void set_binary_op_output_data(
+void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
-BinaryOpType bopt) {
+BinaryOpType bopt,
+bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
@@ -48,7 +52,11 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
-out.copy_shared_buffer(b);
+if (donate_with_move) {
+out.move_shared_buffer(b);
+} else {
+out.copy_shared_buffer(b);
+}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
@@ -59,7 +67,11 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::VectorScalar:
if (a_donatable) {
-out.copy_shared_buffer(a);
+if (donate_with_move) {
+out.move_shared_buffer(a);
+} else {
+out.copy_shared_buffer(a);
+}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
@@ -70,9 +82,17 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::VectorVector:
if (a_donatable) {
-out.copy_shared_buffer(a);
+if (donate_with_move) {
+out.move_shared_buffer(a);
+} else {
+out.copy_shared_buffer(a);
+}
} else if (b_donatable) {
-out.copy_shared_buffer(b);
+if (donate_with_move) {
+out.move_shared_buffer(b);
+} else {
+out.copy_shared_buffer(b);
+}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
@@ -83,10 +103,18 @@ inline void set_binary_op_output_data(
break;
case BinaryOpType::General:
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
-out.copy_shared_buffer(a);
+if (donate_with_move) {
+out.move_shared_buffer(a);
+} else {
+out.copy_shared_buffer(a);
+}
} else if (
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
-out.copy_shared_buffer(b);
+if (donate_with_move) {
+out.move_shared_buffer(b);
+} else {
+out.copy_shared_buffer(b);
+}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
@@ -94,4 +122,409 @@ inline void set_binary_op_output_data(
}
}
struct UseDefaultBinaryOp {};
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
Op op;
DefaultVectorScalar(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
while (size-- > 0) {
*dst = op(*a, scalar);
dst++;
a++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
Op op;
DefaultScalarVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
while (size-- > 0) {
*dst = op(scalar, *b);
dst++;
b++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
Op op;
DefaultVectorVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
while (size-- > 0) {
*dst = op(*a, *b);
dst++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
} else {
*out = op(*a, *b);
}
}
out += stride_out;
a += stride_a;
b += stride_b;
}
}
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
size_t stride = out_strides[dim - 4];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
}
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
value) {
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
}
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, ops...);
break;
case uint8:
binary_op<uint8_t>(a, b, out, ops...);
break;
case uint16:
binary_op<uint16_t>(a, b, out, ops...);
break;
case uint32:
binary_op<uint32_t>(a, b, out, ops...);
break;
case uint64:
binary_op<uint64_t>(a, b, out, ops...);
break;
case int8:
binary_op<int8_t>(a, b, out, ops...);
break;
case int16:
binary_op<int16_t>(a, b, out, ops...);
break;
case int32:
binary_op<int32_t>(a, b, out, ops...);
break;
case int64:
binary_op<int64_t>(a, b, out, ops...);
break;
case float16:
binary_op<float16_t>(a, b, out, ops...);
break;
case float32:
binary_op<float>(a, b, out, ops...);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, ops...);
break;
case complex64:
binary_op<complex64_t>(a, b, out, ops...);
break;
}
}
} // namespace
} // namespace mlx::core
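get_binary_op_type and the stride analysis above funnel every broadcastable pair into one of a handful of loop shapes. A condensed, hypothetical version of the classification with the array metadata reduced to two booleans per operand (the real test also distinguishes row- and column-contiguity):

#include <cstdio>

enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, VectorVector, General };

// Same decision order as get_binary_op_type above, with the array
// metadata reduced to "is a scalar" and "is contiguous" per operand.
BinaryOpType classify(bool a_scalar, bool a_contig, bool b_scalar, bool b_contig) {
  if (a_scalar && b_scalar) return BinaryOpType::ScalarScalar;
  if (a_scalar && b_contig) return BinaryOpType::ScalarVector;
  if (b_scalar && a_contig) return BinaryOpType::VectorScalar;
  if (a_contig && b_contig) return BinaryOpType::VectorVector;
  return BinaryOpType::General;
}

int main() {
  // scalar + contiguous vector -> one ScalarVector pass over b
  std::printf("%d\n", (int)classify(true, true, false, true)); // 1
}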


@@ -2,8 +2,8 @@
#pragma once
+#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
-#include "mlx/backend/cpu/binary.h"
namespace mlx::core {
@@ -16,10 +16,10 @@ void binary_op_dims(
U* out_a,
U* out_b,
Op op,
-const Shape& shape,
+const std::vector<int>& shape,
-const Strides& a_strides,
+const std::vector<size_t>& a_strides,
-const Strides& b_strides,
+const std::vector<size_t>& b_strides,
-const Strides& out_strides,
+const std::vector<size_t>& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
-const auto& a_strides = strides[0];
-const auto& b_strides = strides[1];
-const auto& out_strides = strides[2];
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
+const auto& a_strides = strides[0];
+const auto& b_strides = strides[1];
+const auto& out_strides = strides[2];
int ndim = shape.size();
switch (ndim) {
case 1:
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
return;
}
-ContiguousIterator a_it(shape, a_strides, ndim - 2);
+ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
-ContiguousIterator b_it(shape, b_strides, ndim - 2);
+ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
-auto stride = out_strides[ndim - 3];
+size_t stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
@@ -120,10 +120,14 @@ template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
-array& out_a,
-array& out_b,
-Op op,
-BinaryOpType bopt) {
+std::vector<array>& outputs,
+Op op) {
+auto bopt = get_binary_op_type(a, b);
+auto& out_a = outputs[0];
+auto& out_b = outputs[1];
+set_binary_op_output_data(a, b, out_a, bopt);
+set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
@@ -137,14 +141,14 @@ void binary_op(
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
-for (size_t i = 0; i < b.data_size(); ++i) {
+for (size_t i = 0; i < b.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
-for (size_t i = 0; i < a.data_size(); ++i) {
+for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
@@ -161,6 +165,55 @@ void binary_op(
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}
} // namespace
} // namespace mlx::core
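Each element of the two-output path produces a std::pair that std::tie unpacks into both outputs in a single pass. The same pattern on plain vectors:

#include <cstdio>
#include <tuple>
#include <utility>
#include <vector>

int main() {
  std::vector<int> a{7, 8, 9}, b{2, 3, 4};
  std::vector<int> quo(3), rem(3);
  auto op = [](int x, int y) { return std::make_pair(x / y, x % y); };
  for (size_t i = 0; i < a.size(); ++i) {
    // One pass writes both outputs, as in the VectorVector branch above.
    std::tie(quo[i], rem[i]) = op(a[i], b[i]);
  }
  std::printf("%d %d\n", quo[0], rem[0]); // 3 1
}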


@@ -0,0 +1,74 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
}
}
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core
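A worked instance of the uplo trick: because A is symmetric, the row-major lower factor is bit-identical to the column-major upper factor, which is why the code requests the opposite triangle from LAPACK. Sketch for a 2x2 SPD matrix (hand-rolled, not the LAPACK call):

#include <cmath>
#include <cstdio>

int main() {
  // A = [[4, 2], [2, 3]] is symmetric positive definite.
  double a00 = 4, a01 = 2, a11 = 3;
  // Lower-triangular factor L with A = L * L^T.
  double l00 = std::sqrt(a00);             // 2
  double l10 = a01 / l00;                  // 1
  double l11 = std::sqrt(a11 - l10 * l10); // sqrt(2)
  std::printf("L = [[%g, 0], [%g, %g]]\n", l00, l10, l11);
  // Reconstruct A from L * L^T: prints 4 2 3.
  std::printf("check: %g %g %g\n", l00 * l00, l00 * l10, l10 * l10 + l11 * l11);
}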


@@ -42,12 +42,14 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
-void broadcast(const array& in, array& out) {
+void Broadcast::eval(const std::vector<array>& inputs, array& out) {
+assert(inputs.size() == 1);
+const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
-Strides strides(out.ndim(), 0);
+std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
@@ -59,14 +61,6 @@ void broadcast(const array& in, array& out) {
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
-void Broadcast::eval(const std::vector<array>& inputs, array& out) {
-broadcast(inputs[0], out);
-}
-void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
-broadcast(inputs[0], out);
-}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
@@ -91,16 +85,6 @@ void Depends::eval(
}
}
-void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
-assert(inputs.size() == 1);
-const auto& in = inputs[0];
-auto strides = in.strides();
-for (auto ax : axes_) {
-strides.insert(strides.begin() + ax, 1);
-}
-out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
-}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -151,16 +135,15 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
-case float64:
-*out.data<double>() = static_cast<double>(numel);
-break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
-std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
+std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
+const array& in,
+const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
@@ -168,7 +151,8 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// Special case for scalars
if (in.ndim() == 0) {
-return {false, Strides(out.ndim(), 0)};
+std::vector<size_t> out_strides(out.ndim(), 0);
+return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
@@ -176,7 +160,7 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
-Strides out_strides;
+std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
@@ -197,9 +181,9 @@ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
return {copy_necessary, out_strides};
}
-void shared_buffer_reshape(
+void Reshape::shared_buffer_reshape(
const array& in,
-const Strides& out_strides,
+const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
@@ -265,18 +249,16 @@ void Split::eval(
}
}
-void Squeeze::eval(const std::vector<array>& inputs, array& out) {
-assert(inputs.size() == 1);
-const auto& in = inputs[0];
-Strides strides;
-for (int i = 0, j = 0; i < in.ndim(); ++i) {
-if (j < axes_.size() && i == axes_[j]) {
-j++;
-} else {
-strides.push_back(in.strides(i));
-}
-}
-out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
-}
+std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
+const array& in) {
+int64_t data_offset = 0;
+std::vector<int64_t> inp_strides(in.ndim(), 0);
+for (int i = 0; i < in.ndim(); ++i) {
+data_offset += start_indices_[i] * in.strides()[i];
+inp_strides[i] = in.strides()[i] * strides_[i];
+}
+return std::make_tuple(data_offset, inp_strides);
+}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
@@ -286,7 +268,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
-Strides out_strides(out.ndim());
+std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
@@ -303,8 +285,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
-int64_t f_stride = 1;
-int64_t b_stride = 1;
+size_t f_stride = 1;
+size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
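Transpose::eval moves no data: it permutes the stride vector so that indexing through the new strides visits the old buffer in transposed order. A minimal sketch of stride-permuted indexing:

#include <cstdio>

int main() {
  // 2x3 row-major buffer; strides are {3, 1}.
  int data[6] = {0, 1, 2, 3, 4, 5};
  size_t strides[2] = {3, 1};
  // Transposing axes {1, 0} just swaps the strides: {1, 3}.
  size_t t_strides[2] = {strides[1], strides[0]};
  // Element (i, j) of the 3x2 transpose reads the original buffer in place.
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 2; ++j) {
      std::printf("%d ", data[i * t_strides[0] + j * t_strides[1]]);
    }
    std::printf("\n"); // prints 0 3 / 1 4 / 2 5
  }
}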


@@ -130,7 +130,7 @@ std::string build_lib_name(
bool compiled_check_contiguity(
const std::vector<array>& inputs,
-const Shape& shape) {
+const std::vector<int>& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;
@@ -161,10 +161,11 @@ void compiled_allocate_outputs(
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
-bool contiguous) {
+bool contiguous,
+bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
-Strides strides;
+std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
@@ -177,7 +178,11 @@ void compiled_allocate_outputs(
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
-outputs[o++].copy_shared_buffer(in);
+if (move_buffers) {
+outputs[o++].move_shared_buffer(in);
+} else {
+outputs[o++].copy_shared_buffer(in);
+}
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
@@ -205,8 +210,13 @@ void compiled_allocate_outputs(
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
-outputs[o].copy_shared_buffer(
-in, outputs[o].strides(), in.flags(), in.data_size());
+if (move_buffers) {
+outputs[o].move_shared_buffer(
+in, outputs[o].strides(), in.flags(), in.data_size());
+} else {
+outputs[o].copy_shared_buffer(
+in, outputs[o].strides(), in.flags(), in.data_size());
+}
o++;
}
}
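The donation checks above reuse an input buffer for an output whenever element sizes match, the input is not a broadcast scalar, it is donatable, and it is not a captured constant. A condensed illustration of that predicate, with booleans standing in for MLX's array state:

#include <cstdio>

// Condensed version of the donation predicate used above: an input
// buffer can back an output only if element sizes match, the input is
// not a broadcast scalar, its ownership allows donation, and it is not
// one of the captured constants.
bool can_donate(
    size_t in_itemsize,
    size_t out_itemsize,
    bool is_scalar,
    bool is_donatable,
    bool is_constant) {
  return in_itemsize == out_itemsize && !is_scalar && is_donatable &&
      !is_constant;
}

int main() {
  std::printf("%d\n", can_donate(4, 4, false, true, false)); // 1: reuse buffer
  std::printf("%d\n", can_donate(4, 2, false, true, false)); // 0: must allocate
}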


@@ -11,7 +11,9 @@
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
-return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
+return (
+typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
+typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(
@@ -54,7 +56,7 @@ inline bool is_scalar(const array& x) {
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
-const Shape& shape);
+const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
@@ -62,6 +64,7 @@ void compiled_allocate_outputs(
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
-bool contiguous);
+bool contiguous,
+bool move_buffers = false);
} // namespace mlx::core


@@ -7,12 +7,8 @@
#include <mutex>
#include <shared_mutex>
-#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
-#include "mlx/backend/cpu/compiled_preamble.h"
+#include "mlx/backend/common/compiled_preamble.h"
-#include "mlx/backend/cpu/encoder.h"
-#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
@@ -48,9 +44,12 @@ namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
+std::string get_temp_file(const std::string& name) {
+return std::filesystem::temp_directory_path().append(name);
+}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
@@ -69,30 +68,24 @@ void* compile(
std::string source_code = source_builder();
std::string kernel_file_name;
-// Deal with long kernel names. Maximum length for filename on macOS is 255
-// characters, and on Windows the maximum length for whole path is 260. Clip
-// file name with a little extra room and append a 16 character hash.
+// Deal with long kernel names. Maximum length for files on macOS is 255
+// characters. Clip file name with a little extra room and append a 16
+// character hash.
-#ifdef _WIN32
-constexpr int max_file_name_length = 140;
-#else
constexpr int max_file_name_length = 245;
-#endif
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
-auto file_id =
-std::hash<std::string>{}(kernel_name.substr(max_file_name_length - 16));
+auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
-auto output_dir = std::filesystem::temp_directory_path();
-std::string shared_lib_name = "lib" + kernel_file_name + ".so";
-auto shared_lib_path = (output_dir / shared_lib_name).string();
+std::ostringstream shared_lib_name;
+shared_lib_name << "lib" << kernel_file_name << ".so";
+auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
@@ -101,21 +94,24 @@ void* compile(
if (!lib_exists) {
// Open source file and write source code to it
-std::string source_file_name = kernel_file_name + ".cpp";
-auto source_file_path = (output_dir / source_file_name).string();
+std::ostringstream source_file_name;
+source_file_name << kernel_file_name << ".cpp";
+auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
-try {
-JitCompiler::exec(JitCompiler::build_command(
-output_dir, source_file_name, shared_lib_name));
-} catch (const std::exception& error) {
-throw std::runtime_error(fmt::format(
-"[Compile::eval_cpu] Failed to compile function {0}: {1}",
-kernel_name,
-error.what()));
+std::ostringstream build_command;
+build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
+<< source_file_path << "' -o '" << shared_lib_path << "'";
+std::string build_command_str = build_command.str();
+auto return_code = system(build_command_str.c_str());
+if (return_code) {
+std::ostringstream msg;
+msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
+<< " with error code " << return_code << "." << std::endl;
+throw std::runtime_error(msg.str());
}
}
@@ -155,11 +151,6 @@ inline void build_kernel(
NodeNamer namer;
-#ifdef _MSC_VER
-// Export the symbol
-os << "__declspec(dllexport) ";
-#endif
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
@@ -288,8 +279,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
-auto contiguous = compiled_check_contiguity(inputs, shape);
+bool contiguous = compiled_check_contiguity(inputs, shape);
-auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -300,7 +290,6 @@ void Compiled::eval_cpu(
continue;
}
auto& x = inputs[i];
-encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
@@ -359,25 +348,18 @@ void Compiled::eval_cpu(
});
compiled_allocate_outputs(
-inputs, outputs, inputs_, constant_ids_, contiguous);
+inputs, outputs, inputs_, constant_ids_, contiguous, false);
for (auto& x : outputs) {
args.push_back(x.data<void>());
-encoder.set_output_array(x);
}
-Shape out_shape;
if (!contiguous) {
-out_shape = outputs[0].shape();
-args.push_back((void*)out_shape.data());
+args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
-encoder.dispatch(
-[fun,
-args = std::move(args),
-strides = std::move(strides),
-out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+fun(args.data());
}
} // namespace mlx::core
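The clipping scheme above keeps generated file names under the 255-character macOS limit by truncating the kernel name and appending a 16-hex-digit hash of the full name, so distinct long names stay distinct. A standalone sketch of the same idea (setfill added here so the suffix has a fixed width):

#include <functional>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

// Clip a long kernel name and append a 16-hex-char hash, mirroring the
// scheme above: names at or under max_len pass through unchanged.
std::string clip_name(const std::string& name, size_t max_len = 245) {
  if (name.size() <= max_len) {
    return name;
  }
  std::ostringstream out;
  out << name.substr(0, max_len - 16);
  out << "_" << std::hex << std::setw(16) << std::setfill('0')
      << std::hash<std::string>{}(name);
  return out.str();
}

int main() {
  // 229 kept characters plus "_" plus 16 hex digits = 246.
  std::cout << clip_name(std::string(300, 'k')).size() << "\n";
}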


@@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
-#include "mlx/compile_impl.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/compiled.h"
namespace mlx::core {


@@ -5,8 +5,7 @@
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
-#include "mlx/backend/cpu/unary_ops.h"
-#include "mlx/backend/cpu/binary_ops.h"
+#include "mlx/backend/common/ops.h"
// clang-format on
const char* get_kernel_preamble();

mlx/backend/common/conv.cpp (new file, 1184 lines): diff suppressed because it is too large.


@@ -3,10 +3,8 @@
#include <numeric> #include <numeric>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core { namespace mlx::core {
@@ -14,28 +12,27 @@ namespace {
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) { void copy_single(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto val = static_cast<DstT>(src.data<SrcT>()[0]);
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
auto size = dst.size(); for (int i = 0; i < dst.size(); ++i) {
auto val = static_cast<DstT>(src_ptr[0]); dst_ptr[i] = val;
std::fill_n(dst_ptr, size, val); }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) { void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>(); auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
auto size = src.data_size(); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
std::copy(src_ptr, src_ptr + size, dst_ptr);
} }
template <typename SrcT, typename DstT, int D> template <typename SrcT, typename DstT, typename StrideT, int D>
inline void copy_dims( inline void copy_dims(
const SrcT* src, const SrcT* src,
DstT* dst, DstT* dst,
const Shape& shape, const std::vector<int>& shape,
const Strides& i_strides, const std::vector<StrideT>& i_strides,
const Strides& o_strides, const std::vector<StrideT>& o_strides,
int axis) { int axis) {
auto stride_src = i_strides[axis]; auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis]; auto stride_dst = o_strides[axis];
@@ -43,7 +40,7 @@ inline void copy_dims(
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if constexpr (D > 1) { if constexpr (D > 1) {
copy_dims<SrcT, DstT, D - 1>( copy_dims<SrcT, DstT, StrideT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1); src, dst, shape, i_strides, o_strides, axis + 1);
} else { } else {
*dst = static_cast<DstT>(*src); *dst = static_cast<DstT>(*src);
@@ -53,66 +50,45 @@ inline void copy_dims(
} }
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT, typename StrideT>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
const Shape& data_shape, const std::vector<int>& data_shape,
const Strides& i_strides, const std::vector<StrideT>& i_strides,
const Strides& o_strides, const std::vector<StrideT>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset) {
const std::optional<array>& dynamic_i_offset,
const std::optional<array>& dynamic_o_offset) {
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
auto i_offset_ptr =
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
if (data_shape.empty()) { if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr); auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
auto dst_ptr = dst.data<DstT>() + o_offset;
*dst_ptr = val; *dst_ptr = val;
return; return;
} }
auto [shape, strides] = auto [shape, strides] = collapse_contiguous_dims(
collapse_contiguous_dims(data_shape, {i_strides, o_strides}); data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size(); int ndim = shape.size();
if (ndim < 3) { if (ndim == 1) {
if (i_offset_ptr) { copy_dims<SrcT, DstT, StrideT, 1>(
src_ptr += i_offset_ptr[0]; src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} return;
if (o_offset_ptr) { } else if (ndim == 2) {
dst_ptr += o_offset_ptr[0]; copy_dims<SrcT, DstT, StrideT, 2>(
} src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
if (ndim == 1) { } else if (ndim == 3) {
copy_dims<SrcT, DstT, 1>( copy_dims<SrcT, DstT, StrideT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return; return;
} }
if (i_offset_ptr) { ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
src_ptr += i_offset_ptr[0]; ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
} StrideT stride = std::accumulate(
if (o_offset_ptr) { shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
dst_ptr += o_offset_ptr[0]; for (StrideT elem = 0; elem < src.size(); elem += stride) {
} copy_dims<SrcT, DstT, StrideT, 3>(
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc, src_ptr + in.loc,
dst_ptr + out.loc, dst_ptr + out.loc,
shape, shape,
@@ -126,53 +102,39 @@ void copy_general_general(
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) { inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT, size_t>(
src, src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
dst,
src.shape(),
src.strides(),
dst.strides(),
0,
0,
std::nullopt,
std::nullopt);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT, typename StrideT>
void copy_general( void copy_general(
const array& src, const array& src,
array& dst, array& dst,
const Shape& data_shape, const std::vector<int>& data_shape,
const Strides& i_strides, const std::vector<StrideT>& i_strides,
const Strides&, const std::vector<StrideT>&,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset) {
const std::optional<array>& dynamic_i_offset, copy_general_general<SrcT, DstT, StrideT>(
const std::optional<array>& dynamic_o_offset) {
copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
data_shape, data_shape,
i_strides, i_strides,
make_contiguous_strides(data_shape), make_contiguous_strides<StrideT>(data_shape),
i_offset, i_offset,
o_offset, o_offset);
dynamic_i_offset,
dynamic_o_offset);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) { inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT>( copy_general_general<SrcT, DstT, size_t>(
src, src,
dst, dst,
src.shape(), src.shape(),
src.strides(), src.strides(),
make_contiguous_strides(src.shape()), make_contiguous_strides<size_t>(src.shape()),
0, 0,
0, 0);
std::nullopt,
std::nullopt);
} }
template <typename SrcT, typename DstT, typename... Args> template <typename SrcT, typename DstT, typename... Args>
@@ -229,9 +191,6 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
case float32: case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...); copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
case float64:
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16: case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...); copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break; break;
@@ -281,9 +240,6 @@ inline void copy_inplace_dispatch(
     case float32:
       copy<float>(src, dst, ctype, std::forward<Args>(args)...);
       break;
-    case float64:
-      copy<double>(src, dst, ctype, std::forward<Args>(args)...);
-      break;
     case bfloat16:
       copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
       break;
@@ -295,82 +251,84 @@ inline void copy_inplace_dispatch(
 } // namespace

-void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  encoder.set_output_array(dst);
-  encoder.dispatch(
-      [src = array::unsafe_weak_copy(src),
-       dst = array::unsafe_weak_copy(dst),
-       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
+void copy_inplace(const array& src, array& dst, CopyType ctype) {
+  copy_inplace_dispatch(src, dst, ctype);
 }

-void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
-  bool donated = set_copy_output_data(src, dst, ctype);
-  if (donated && src.dtype() == dst.dtype()) {
-    // If the output has the same type as the input then there is nothing to
-    // copy, just use the buffer.
-    return;
+void copy(const array& src, array& dst, CopyType ctype) {
+  // Allocate the output
+  switch (ctype) {
+    case CopyType::Vector:
+      if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
+        dst.copy_shared_buffer(src);
+      } else {
+        auto size = src.data_size();
+        dst.set_data(
+            allocator::malloc_or_wait(size * dst.itemsize()),
+            size,
+            src.strides(),
+            src.flags());
+      }
+      break;
+    case CopyType::Scalar:
+    case CopyType::General:
+    case CopyType::GeneralGeneral:
+      dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
+      break;
   }
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
   }
-  copy_inplace(src, dst, ctype, stream);
+  copy_inplace(src, dst, ctype);
 }
+template <typename StrideT>
 void copy_inplace(
     const array& src,
     array& dst,
-    const Shape& data_shape,
-    const Strides& i_strides,
-    const Strides& o_strides,
+    const std::vector<int>& data_shape,
+    const std::vector<StrideT>& i_strides,
+    const std::vector<StrideT>& o_strides,
     int64_t i_offset,
     int64_t o_offset,
-    CopyType ctype,
-    Stream stream,
-    const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
-    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(src);
-  encoder.set_output_array(dst);
-  auto weak_copy_if_set = [](auto x) -> std::optional<array> {
-    if (x) {
-      return array::unsafe_weak_copy(*x);
-    } else {
-      return std::nullopt;
-    }
-  };
-  encoder.dispatch(
-      [src = array::unsafe_weak_copy(src),
-       dst = array::unsafe_weak_copy(dst),
-       data_shape,
-       i_strides,
-       o_strides,
-       i_offset,
-       o_offset,
-       ctype,
-       dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
-       dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
-        switch (ctype) {
-          case CopyType::General:
-          case CopyType::GeneralGeneral:
-            copy_inplace_dispatch(
-                src,
-                dst,
-                ctype,
-                data_shape,
-                i_strides,
-                o_strides,
-                i_offset,
-                o_offset,
-                dynamic_i_offset,
-                dynamic_o_offset);
-            break;
-          case CopyType::Scalar:
-          case CopyType::Vector:
-            copy_inplace_dispatch(src, dst, ctype);
-        }
-      });
+    CopyType ctype) {
+  switch (ctype) {
+    case CopyType::General:
+    case CopyType::GeneralGeneral:
+      copy_inplace_dispatch(
+          src,
+          dst,
+          ctype,
+          data_shape,
+          i_strides,
+          o_strides,
+          i_offset,
+          o_offset);
+      break;
+    case CopyType::Scalar:
+    case CopyType::Vector:
+      copy_inplace_dispatch(src, dst, ctype);
+  }
 }
+
+template void copy_inplace<size_t>(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<size_t>& i_strides,
+    const std::vector<size_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype);
+
+template void copy_inplace<int64_t>(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<int64_t>& i_strides,
+    const std::vector<int64_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype);

 } // namespace mlx::core
@@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/array.h"
+#include "mlx/backend/common/utils.h"

 namespace mlx::core {
@@ -22,25 +23,18 @@ enum class CopyType {
   GeneralGeneral
 };

-inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
-  if (ctype == CopyType::Vector) {
-    // If the input is donateable, we are doing a vector copy and the types
-    // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
-      out.copy_shared_buffer(in);
-      return true;
-    } else {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-      return false;
-    }
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    return false;
-  }
-}
+void copy(const array& src, array& dst, CopyType ctype);
+void copy_inplace(const array& src, array& dst, CopyType ctype);
+
+template <typename stride_t>
+void copy_inplace(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype);

 } // namespace mlx::core
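Call sites throughout this diff pick a CopyType from the input's layout: Scalar broadcasts a single element, Vector copies between two contiguous buffers, General gathers from a strided input into a contiguous output, and GeneralGeneral applies stride arithmetic on both sides. A minimal sketch of the recurring selection pattern (a hypothetical helper for illustration, not a function in the codebase):

// Sketch only: the selection logic seen at call sites such as AsType and AddMM.
inline CopyType choose_copy_type(const array& in) {
  if (in.data_size() == 1) {
    return CopyType::Scalar; // one element broadcast to the whole output
  }
  // Contiguous data admits a flat element-by-element loop; anything else
  // needs index arithmetic on the input side.
  return in.flags().contiguous ? CopyType::Vector : CopyType::General;
}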
@@ -0,0 +1,196 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
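To make the indirection concrete, this is the mechanical expansion of DEFAULT(Abs) under the macro above (nothing beyond what the preprocessor produces):

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  Abs::eval(inputs, out);
}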
namespace mlx::core {
DEFAULT(Abs)
DEFAULT(Add)
DEFAULT(Arange)
DEFAULT(ArcCos)
DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(Multiply)
DEFAULT(Negative)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {
inline void matmul_common_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_common_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
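The check_transpose lambdas above encode one rule: last-two strides of {shape(-1), 1} mean the matrix is row-major as-is, {1, shape(-2)} mean it is a transposed view, and any other layout forces a copy. A freestanding sketch of that stride test with made-up numbers (illustration only, not MLX API):

#include <cstdio>

int main() {
  long shape[2] = {4, 3};   // a 4x3 matrix
  long strides[2] = {1, 4}; // column-major layout of the same data
  bool transposed = (strides[0] == 1 && strides[1] == shape[0]);
  long ld = transposed ? strides[1] : strides[0];
  // Row-major strides {3, 1} would give transposed=0, ld=3.
  std::printf("transposed=%d ld=%ld\n", transposed, ld); // transposed=1 ld=4
  return 0;
}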
mlx/backend/common/eigh.cpp Normal file
@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core
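The paired ssyevd calls follow the standard LAPACK workspace-query convention: a first call with lwork = -1 performs no decomposition and instead writes the optimal buffer sizes into work and iwork, which the real call then allocates. The pattern in isolation, using the wrapper defined above:

float work_query;
int iwork_query;
// Query: lwork = liwork = -1 asks for sizes only, computes nothing.
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work_query, -1, &iwork_query, -1);
int lwork = static_cast<int>(work_query); // floats needed
int liwork = iwork_query;                 // ints needed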
@@ -0,0 +1,40 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
float erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
} // namespace mlx::core
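A quick way to sanity-check the polynomial above: composing it with std::erf should return the identity to within a few ULPs on (-1, 1). A small driver (test scaffolding, not part of the file):

#include <cmath>
#include <cstdio>

float erfinv(float a); // the approximation defined above

int main() {
  for (float x = -0.9f; x <= 0.9f; x += 0.3f) {
    std::printf("x=% .2f  erfinv(erf(x))=% .6f\n", x, erfinv(std::erf(x)));
  }
  return 0;
}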
@@ -0,0 +1,87 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
std::vector<std::ptrdiff_t> strides_in(
in.strides().begin(), in.strides().end());
for (auto& s : strides_in) {
s *= in.itemsize();
}
std::vector<std::ptrdiff_t> strides_out(
out.strides().begin(), out.strides().end());
for (auto& s : strides_out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {
shape.insert(shape.end(), out.shape().begin(), out.shape().end());
} else {
shape.insert(shape.end(), in.shape().begin(), in.shape().end());
}
float scale = 1.0f;
if (inverse_) {
size_t nelem = std::accumulate(
axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
return x * shape[y];
});
scale /= nelem;
}
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");
}
}
} // namespace mlx::core
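The inverse-transform scale computed above is simply the reciprocal of the element count along the transformed axes:

\[ \text{scale} = \frac{1}{\prod_{a \in \text{axes}} \text{shape}[a]} \]

so an inverse FFT over two axes of lengths 8 and 16, say, divides by 128; the forward transform leaves scale at 1.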
@@ -2,19 +2,18 @@
 #include <cassert>

+#include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/hadamard.h"
-#include "mlx/backend/cpu/copy.h"
-#include "mlx/backend/cpu/encoder.h"
 #include "mlx/primitives.h"

 namespace mlx::core {

 // n = 2^k component
 template <typename T>
-void hadamard_n(T* out, int n, int m, float scale, size_t size) {
-  for (int b = 0; b < size / n; b++) {
+void hadamard_n(array& out, int n, int m, float scale) {
+  for (int b = 0; b < out.size() / n; b++) {
     size_t loc = b * n;
-    T* data_ptr = out + loc;
+    T* data_ptr = out.data<T>() + loc;
     int h = 1;
     int n_over_2 = n / 2;
     while (h < n) {
@@ -37,7 +36,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {
 // m component
 template <typename T>
-void hadamard_m(T* out, int n, int m, float scale, size_t size) {
+void hadamard_m(array& out, int n, int m, float scale) {
   auto h_matrices = hadamard_matrices();
   auto& matrix = h_matrices[m];
   auto start = 1;
@@ -52,9 +51,9 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
     end = matrix.find('\n', start);
   }
-  for (int b = 0; b < size / m / n; b++) {
+  for (int b = 0; b < out.size() / m / n; b++) {
     size_t loc = b * n * m;
-    T* data_ptr = out + loc;
+    T* data_ptr = out.data<T>() + loc;
     for (int i = 0; i < n; i++) {
       std::vector<float> out(m);
       for (int j = 0; j < m; j++) {
@@ -75,47 +74,34 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
 }

 template <typename T>
-void hadamard(array& out, int n, int m, float scale, Stream stream) {
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_output_array(out);
-  auto out_ptr = out.data<T>();
-  encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
-    float n_scale = m > 1 ? 1.0 : scale;
-    hadamard_n<T>(out_ptr, n, m, n_scale, size);
-    if (m > 1) {
-      hadamard_m<T>(out_ptr, n, m, scale, size);
-    }
-  });
+void hadamard(array& out, int n, int m, float scale) {
+  float n_scale = m > 1 ? 1.0 : scale;
+  hadamard_n<T>(out, n, m, n_scale);
+  if (m > 1) {
+    hadamard_m<T>(out, n, m, scale);
+  }
 }

-void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
+void Hadamard::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];

   // Copy input to output
-  if (in.flags().row_contiguous && in.is_donatable()) {
-    out.copy_shared_buffer(in);
-  } else {
-    copy(
-        in,
-        out,
-        in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
-        stream());
-  }
+  copy(in, out, CopyType::General);

   int axis = out.ndim() - 1;
   auto [n, m] = decompose_hadamard(out.shape(axis));

   switch (in.dtype()) {
     case float32:
-      return hadamard<float>(out, n, m, scale_, stream());
+      return hadamard<float>(out, n, m, scale_);
     case float16:
-      return hadamard<float16_t>(out, n, m, scale_, stream());
+      return hadamard<float16_t>(out, n, m, scale_);
     case bfloat16:
-      return hadamard<bfloat16_t>(out, n, m, scale_, stream());
+      return hadamard<bfloat16_t>(out, n, m, scale_);
     default:
       throw std::invalid_argument("[hadamard] Unsupported type.");
   }
 }

 } // namespace mlx::core
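For readers tracing hadamard_n, the loop it dispatches is the standard in-place fast Walsh-Hadamard butterfly. A freestanding sketch on plain arrays (no MLX types; n must be a power of two):

#include <cstdio>

void fwht(float* x, int n) {
  for (int h = 1; h < n; h *= 2) {        // block half-width doubles each pass
    for (int i = 0; i < n; i += 2 * h) {  // walk the blocks
      for (int j = i; j < i + h; j++) {   // butterfly on the pair (j, j + h)
        float a = x[j], b = x[j + h];
        x[j] = a + b;
        x[j + h] = a - b;
      }
    }
  }
}

int main() {
  float x[4] = {1, 0, 0, 0};
  fwht(x, 4); // an impulse maps to a constant row of H_4
  std::printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]); // 1 1 1 1
  return 0;
}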
@@ -0,0 +1,394 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT>
void gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size
// If the array is col contiguous then the reverse is the case:
// - Any number of trailing ones in the slice sizes are allowed
// - All other slice sizes match the corresponding dimension except the
// first non-singleton slice size from the end
bool can_copy = false;
if (src.flags().row_contiguous) {
can_copy = true;
// Ignore leading 1s
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
// Check the remaining
i++;
for (; i < src.ndim() && can_copy; ++i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
} else if (src.flags().col_contiguous) {
can_copy = true;
// Ignore trailing 1s
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
// Skip the next slice size and check the remaining
i--;
for (; i >= 0 && can_copy; --i) {
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
size_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]);
}
if (slice_size == 1) {
dst_ptr[out_idx++] = src_ptr[src_idx];
} else if (can_copy) {
std::copy(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
src_it.reset();
}
}
}
template <typename IdxT>
void dispatch_gather(
const array& src,
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
break;
case uint8:
gather<uint8_t, IdxT>(src, inds, out, axes, size);
break;
case uint16:
gather<uint16_t, IdxT>(src, inds, out, axes, size);
break;
case uint32:
gather<uint32_t, IdxT>(src, inds, out, axes, size);
break;
case uint64:
gather<uint64_t, IdxT>(src, inds, out, axes, size);
break;
case int8:
gather<int8_t, IdxT>(src, inds, out, axes, size);
break;
case int16:
gather<int16_t, IdxT>(src, inds, out, axes, size);
break;
case int32:
gather<int32_t, IdxT>(src, inds, out, axes, size);
break;
case int64:
gather<int64_t, IdxT>(src, inds, out, axes, size);
break;
case float16:
gather<float16_t, IdxT>(src, inds, out, axes, size);
break;
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
case complex64:
gather<complex64_t, IdxT>(src, inds, out, axes, size);
break;
}
}
void Gather::eval(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case bool_:
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
break;
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint16:
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint32:
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
break;
case uint64:
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
break;
case int8:
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
break;
case int16:
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
break;
case int32:
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
break;
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
case float16:
case float32:
case bfloat16:
case complex64:
throw std::runtime_error(
"[Gather::eval] Cannot gather with floating point indices.");
break;
}
}
template <typename InT, typename IdxT, typename OpT>
void scatter(
const array& updates,
array& out,
const std::vector<array>& inds,
const std::vector<int>& axes,
const OpT& op) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
std::vector<int> update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < nind; ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
op(updates.data<InT>()[update_it.loc],
out.data<InT>() + out_offset + out_it.loc);
update_it.step();
out_it.step();
}
out_it.reset();
update_it.reset();
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_inds(
array& out,
const std::vector<array>& indices,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
switch (rtype) {
case Scatter::None:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
break;
case Scatter::Sum:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
break;
case Scatter::Prod:
scatter<InT, IdxT>(
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
break;
case Scatter::Max:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y > x) ? *y : x;
});
break;
case Scatter::Min:
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
(*y) = (*y < x) ? *y : x;
});
break;
}
}
template <typename InT>
void dispatch_scatter(
array& out,
const std::vector<array>& inds,
const array& updates,
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
case bool_:
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
break;
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
case uint16:
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
break;
case uint32:
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
break;
case uint64:
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
break;
case int8:
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
break;
case int16:
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
break;
case int32:
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
break;
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
case float16:
case float32:
case bfloat16:
case complex64:
throw std::runtime_error(
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
}
}
void Scatter::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
copy(src, out, CopyType::General);
switch (src.dtype()) {
case bool_:
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
break;
case uint8:
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint16:
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint32:
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
break;
case uint64:
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
break;
case int8:
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
break;
case int16:
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
break;
case int32:
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
break;
case int64:
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
break;
case float16:
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
break;
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
case complex64:
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
break;
}
}
} // namespace mlx::core
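Both gather and scatter normalize indices through offset_neg_idx, which wraps negative values Python-style; the bool and uint32 specializations skip the (always-false) sign test. A small self-contained illustration:

#include <cassert>
#include <cstddef>

// Restated from the top of this file: negative indices wrap Python-style.
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

int main() {
  assert(offset_neg_idx(-1, size_t(5)) == 4); // -1 means the last of 5 elements
  assert(offset_neg_idx(3, size_t(5)) == 3);  // non-negative indices pass through
  return 0;
}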
@@ -0,0 +1,120 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
return info;
}
namespace mlx::core {
void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv(inv, N, i, upper);
} else {
general_inv(inv, N, i);
}
}
}
void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output, tri_, upper_);
}
} // namespace mlx::core
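The comment's identity, spelled out: column-major LAPACK reads the row-major buffer of \(A\) as \(A^\top\), and since

\[ (A^\top)^{-1} = (A^{-1})^\top, \]

the in-place result it leaves behind, read back row-major, is exactly \(A^{-1}\); no transpose is ever materialized on either side of the call.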
@@ -0,0 +1,24 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// This is to work around a change in the function signatures of lapack >= 3.9.1
// where functions taking char* also include a strlen argument, see a similar
// change in OpenCV:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
#define MLX_LAPACK_FUNC(f) f##_
#endif
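Concretely, the two branches resolve the same call site to differently mangled LAPACK symbols (a mechanical expansion of the macro above, nothing new):

// With LAPACK >= 3.9.1 (LAPACK_GLOBAL or LAPACK_NAME defined):
//   MLX_LAPACK_FUNC(ssyevd)  ->  LAPACK_ssyevd
// Otherwise, including Accelerate's new LAPACK:
//   MLX_LAPACK_FUNC(ssyevd)  ->  ssyevd_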
@@ -1,10 +1,12 @@
 // Copyright © 2023 Apple Inc.

 #include <algorithm>
+#include <cassert>
 #include <utility>

+#include "mlx/allocator.h"
+#include "mlx/backend/common/load.h"
 #include "mlx/primitives.h"
-#include "mlx/scheduler.h"

 namespace {
@@ -27,31 +29,33 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
 namespace mlx::core {

-void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  auto read_task = [out_ptr = out.data<char>(),
-                    size = out.size(),
-                    itemsize = out.itemsize(),
-                    offset = offset_,
-                    reader = reader_,
-                    swap_endianness_ = swap_endianness_]() mutable {
-    reader->read(out_ptr, size * itemsize, offset);
-    if (swap_endianness_) {
-      switch (itemsize) {
-        case 2:
-          swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
-          break;
-        case 4:
-          swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
-          break;
-        case 8:
-          swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
-          break;
-      }
-    }
-  };
-  auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
-  scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
+void load(
+    array& out,
+    size_t offset,
+    const std::shared_ptr<io::Reader>& reader,
+    bool swap_endianness_) {
+  reader->read(out.data<char>(), out.nbytes(), offset);
+  if (swap_endianness_) {
+    switch (out.itemsize()) {
+      case 2:
+        swap_endianness<2>(out.data<uint8_t>(), out.data_size());
+        break;
+      case 4:
+        swap_endianness<4>(out.data<uint8_t>(), out.data_size());
+        break;
+      case 8:
+        swap_endianness<8>(out.data<uint8_t>(), out.data_size());
+        break;
+    }
+  }
+}
+
+void Load::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 0);
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  load(out, offset_, reader_, swap_endianness_);
 }

 } // namespace mlx::core
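The definition of swap_endianness is elided from this hunk; one plausible implementation consistent with the call sites above (a sketch under that assumption, not necessarily the exact MLX code) reverses the bytes of each scalar_size-byte element in place:

#include <algorithm>
#include <cstddef>
#include <cstdint>

template <int scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
  // N counts elements, not bytes, matching calls like
  // swap_endianness<2>(ptr, out.data_size()).
  for (size_t i = 0; i < N; i++) {
    std::reverse(
        data_bytes + i * scalar_size, data_bytes + (i + 1) * scalar_size);
  }
}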
mlx/backend/common/load.h Normal file
@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/io/load.h"
namespace mlx::core {
void load(
array& out,
size_t offset,
const std::shared_ptr<io::Reader>& reader,
bool swap_endianness);
} // namespace mlx::core
@@ -10,24 +10,20 @@ OUTPUT_FILE=$1
 GCC=$2
 SRCDIR=$3
 CLANG=$4
-ARCH=$5

 if [ "$CLANG" = "TRUE" ]; then
   read -r -d '' INCLUDES <<- EOM
   #include <cmath>
   #include <complex>
   #include <cstdint>
   #include <vector>
-  #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
-  #include <arm_fp16.h>
-  #endif
 EOM
-  CC_FLAGS="-arch ${ARCH} -nobuiltininc -nostdinc"
+  CC_FLAGS=""
 else
   CC_FLAGS="-std=c++17"
 fi

-CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)
+CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)

 cat << EOF > "$OUTPUT_FILE"
 const char* get_kernel_preamble() {
@@ -0,0 +1,300 @@
// Copyright © 2024 Apple Inc.
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename mask_t>
inline void mask_matrix(
T* data,
const mask_t* mask,
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
if (do_mask != 1) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(block_size, X - loc_x);
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
if constexpr (std::is_same_v<mask_t, bool>) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
} else {
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
}
}
}
}
}
}
}
} // namespace
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
auto mask_array = [](const array& mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
size_t mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
if (mask.dtype() == bool_) {
return mask_matrix(
data,
mask.data<bool>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
} else {
return mask_matrix(
data,
mask.data<float>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
}
};
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
block_size_,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);
// Zero out blocks in out
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
}
}
}
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core
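The tile counts in mask_matrix are the usual integer ceiling division,

\[ t_X = \left\lceil \frac{X}{\text{block\_size}} \right\rceil = \left\lfloor \frac{X + \text{block\_size} - 1}{\text{block\_size}} \right\rfloor, \]

so X = 70 with block_size = 32 yields three tiles covering rows 0-31, 32-63, and 64-69, the last one clipped by the std::min(block_size, X - loc_x) bound above.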
mlx/backend/common/ops.h Normal file
@@ -0,0 +1,675 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
namespace mlx::core::detail {
namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace
typedef union {
int i;
float f;
} IntOrFloat;
inline float fast_exp(float x) {
if (x == -std::numeric_limits<float>::infinity()) {
return 0.0f;
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
return x;
}
x *= 1.442695; // multiply with log_2(e)
float ipart, fpart;
IntOrFloat epart;
x = std::max(-80.f, std::min(x, 80.f));
ipart = std::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart.i = (int(ipart) + 127) << 23;
return epart.f * x;
}
inline float fast_erf(float a) {
float r, s, t, u;
t = std::abs(a);
s = a * a;
if (t > 0.927734375f) {
// maximum error 0.99527 ulp
r = std::fma(
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
u = std::fma(
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
r = std::fma(r, s, u);
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
r = std::fma(r, t, -t);
// TODO, replace with expm1 when implemented
r = 1.0f - std::exp(r);
r = std::copysign(r, a);
} else {
// maximum error 0.98929 ulp
r = -5.96761703e-4f; // -0x1.38e000p-11
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
r = std::fma(r, a, a);
}
return r;
}
inline float fast_erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
struct Abs {
template <typename T>
T operator()(T x) {
return std::abs(x);
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return std::acos(x);
}
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return std::acosh(x);
}
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return std::asin(x);
}
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return std::asinh(x);
}
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return std::atan(x);
}
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
}
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return std::atanh(x);
}
};
struct Ceil {
template <typename T>
T operator()(T x) {
return std::ceil(x);
}
int8_t operator()(int8_t x) {
return x;
}
int16_t operator()(int16_t x) {
return x;
}
int32_t operator()(int32_t x) {
return x;
}
int64_t operator()(int64_t x) {
return x;
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return std::conj(x);
}
};
struct Cos {
template <typename T>
T operator()(T x) {
return std::cos(x);
}
};
struct Cosh {
template <typename T>
T operator()(T x) {
return std::cosh(x);
}
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x)));
}
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
}
};
struct Exp {
template <typename T>
T operator()(T x) {
return fast_exp(x);
}
complex64_t operator()(complex64_t x) {
return std::exp(x);
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return std::floor(x);
}
int8_t operator()(int8_t x) {
return x;
}
int16_t operator()(int16_t x) {
return x;
}
int32_t operator()(int32_t x) {
return x;
}
int64_t operator()(int64_t x) {
return x;
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
return std::log(x);
}
};
struct Log2 {
template <typename T>
T operator()(T x) {
return std::log2(x);
}
};
struct Log10 {
template <typename T>
T operator()(T x) {
return std::log10(x);
}
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
}
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
}
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {
return std::rint(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
}
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto one = static_cast<decltype(x)>(1.0);
return one / (one + fast_exp(-x));
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
}
uint8_t operator()(uint8_t x) {
return x != 0;
}
uint16_t operator()(uint16_t x) {
return x != 0;
}
uint32_t operator()(uint32_t x) {
return x != 0;
}
uint64_t operator()(uint64_t x) {
return x != 0;
}
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
};
struct Sin {
template <typename T>
T operator()(T x) {
return std::sin(x);
}
};
struct Sinh {
template <typename T>
T operator()(T x) {
return std::sinh(x);
}
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
}
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return std::sqrt(x);
}
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
}
};
struct Tan {
template <typename T>
T operator()(T x) {
return std::tan(x);
}
};
struct Tanh {
template <typename T>
T operator()(T x) {
return std::tanh(x);
}
};
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
template <typename T>
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (std::isnan(x) && std::isnan(y));
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct Maximum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return (x > y) ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
}
};
struct Minimum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return x < y ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return x < y ? x : y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = Maximum()(x, y);
auto minval = Minimum()(x, y);
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval)));
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
};
struct Power {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
}
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
}
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
}
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
}
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
}
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
}
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
}
};
} // namespace mlx::core::detail
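One detail in fast_exp worth unpacking: after multiplying by \(\log_2 e\) so that \(e^x = 2^{\text{ipart}+\text{fpart}}\), the code builds \(2^{\text{ipart}}\) straight from the IEEE-754 single-precision layout, where a float with zero mantissa and biased exponent \(b\) encodes \(2^{b-127}\):

\[ \texttt{epart.i} = (\text{ipart} + 127) \ll 23 \quad\Longrightarrow\quad \texttt{epart.f} = 2^{(\text{ipart}+127)-127} = 2^{\text{ipart}}, \]

while the polynomial in fpart approximates \(2^{\text{fpart}}\), so the returned product epart.f * x approximates \(e^x\).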
@@ -0,0 +1,636 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), unsignedinteger)) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Abs());
}
}
void Arange::eval(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
}
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos());
} else {
throw std::invalid_argument(
"[arccos] Cannot compute inverse cosine of elements in array"
" with non floating point type.");
}
}
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh());
} else {
throw std::invalid_argument(
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
" array with non floating point type.");
}
}
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin());
} else {
throw std::invalid_argument(
"[arcsin] Cannot compute inverse sine of elements in array"
" with non floating point type.");
}
}
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh());
} else {
throw std::invalid_argument(
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
" array with non floating point type.");
}
}
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan());
} else {
throw std::invalid_argument(
"[arctan] Cannot compute inverse tangent of elements in array"
" with non floating point type.");
}
}
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh());
} else {
throw std::invalid_argument(
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
" array with non floating point type.");
}
}
void AsType::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
    // No-op for integer types
out.copy_shared_buffer(in);
}
}
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
}
}
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_fp(in, out, detail::Conjugate());
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos());
} else {
throw std::invalid_argument(
"[cos] Cannot compute cosine of elements in array"
" with non floating point type.");
}
}
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh());
} else {
throw std::invalid_argument(
"[cosh] Cannot compute hyperbolic cosine of elements in array"
" with non floating point type.");
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp());
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
    // No-op for integer types
out.copy_shared_buffer(in);
}
}
void Full::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
break;
}
} else {
throw std::invalid_argument(
"[log] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p());
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
}
void Pad::eval(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
}
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
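  // Each threefry2x32 hash yields two 32-bit words: the first fills the
  // front half of this key's output and the second fills the back half;
  // an odd word count needs one extra hash for the middle word.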
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
    // No-op for integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid());
} else {
    throw std::invalid_argument(
        "[sigmoid] Cannot compute sigmoid of elements in array with"
        " non floating point type.");
}
}
void Sign::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
}
}
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin());
} else {
throw std::invalid_argument(
"[sin] Cannot compute sine of elements in array"
" with non floating point type.");
}
}
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh());
} else {
throw std::invalid_argument(
"[sinh] Cannot compute hyperbolic sine of elements in array"
" with non floating point type.");
}
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
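    // Size the shared region: data_end bounds one past the last element
    // the slice touches, measured from the start of the input buffer.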
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
}
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tan());
} else {
throw std::invalid_argument(
"[tan] Cannot compute tangent of elements in array"
" with non floating point type.");
}
}
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tanh());
} else {
throw std::invalid_argument(
"[tanh] Cannot compute hyperbolic tangent of elements in array"
" with non floating point type.");
}
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
  // The buffer can be reinterpreted in place (no copy) when any of the
  // following holds (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
      in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
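    // Column-contiguity holds iff at most one axis is non-singleton
    // (then row- and column-major layouts coincide).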
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core
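
RandomBits::eval above generates randomness as a keyed hash of a counter, so each key produces an independent, reproducible stream. For reference, a self-contained sketch of Threefry-2x32 in the spirit of random::threefry2x32_hash, using the published Random123 rotation schedule (an illustration, not necessarily the library's exact code):

#include <cstdint>
#include <utility>

// Threefry-2x32: 20 rounds in 5 blocks of 4, with key injection between
// blocks. Returns two pseudorandom 32-bit words per (key, counter) pair.
std::pair<uint32_t, uint32_t> threefry2x32(
    std::pair<uint32_t, uint32_t> key,
    std::pair<uint32_t, uint32_t> count) {
  constexpr uint32_t rot[2][4] = {{13, 15, 26, 6}, {17, 29, 16, 24}};
  uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
  count.first += ks[0];
  count.second += ks[1];
  for (int i = 0; i < 5; ++i) {
    for (auto r : rot[i % 2]) {
      count.first += count.second;
      count.second = (count.second << r) | (count.second >> (32 - r));
      count.second ^= count.first;
    }
    count.first += ks[(i + 1) % 3];
    count.second += ks[(i + 2) % 3] + i + 1;
  }
  return count;
}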

mlx/backend/common/qrf.cpp Normal file

@@ -0,0 +1,148 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = std::max(M, N);
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
  // Copy A into a working buffer that LAPACK can factorize in place
array in(a.shape(), float32, nullptr, {});
auto flags = in.flags();
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
copy_inplace(in, r, CopyType::General);
for (int i = 0; i < num_matrices; ++i) {
// Zero lower triangle
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * M + j * N + k] = 0;
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
copy_inplace(in, q, CopyType::General);
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core
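
qrf_impl uses LAPACK's two-call workspace convention twice: a call with lwork = -1 only reports the optimal workspace size, and a second call does the actual work. A standalone sketch of that pattern for a single column-major matrix (geqrf_inplace is a hypothetical helper; assumes a Fortran LAPACK is linked):

#include <algorithm>
#include <vector>

extern "C" void sgeqrf_(
    const int* m,
    const int* n,
    float* a,
    const int* lda,
    float* tau,
    float* work,
    const int* lwork,
    int* info);

// QR-factor one column-major m x n matrix in place (R lands in the upper
// triangle, Householder reflectors below it).
void geqrf_inplace(float* a, int m, int n) {
  int info = 0;
  int lwork = -1; // Workspace query: only writes the optimal size to work[0].
  float optimal = 0.0f;
  std::vector<float> tau(std::min(m, n));
  sgeqrf_(&m, &n, a, &m, tau.data(), &optimal, &lwork, &info);
  lwork = static_cast<int>(optimal);
  std::vector<float> work(lwork); // Second call performs the factorization.
  sgeqrf_(&m, &n, a, &m, tau.data(), work.data(), &lwork, &info);
}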


@@ -0,0 +1,407 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, int bits, int group_size>
void _qmm(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
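  // w is packed along the output dimension: each uint32 holds `pack_factor`
  // elements of width `bits`, and every `group_size` columns share one
  // (scale, bias) pair.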
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
uint32_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
}
result += N;
}
}
template <typename T, int bits, int group_size>
void _qmm_t(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Kg = K / group_size;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
uint32_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
*result = sum;
result++;
}
x += K;
}
}
template <typename T>
void _qmm_dispatch_typed(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
int bits,
bool transposed_w) {
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
      }
      break;
    }
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
      }
      break;
    }
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
      }
      break;
    }
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bits=" << bits
<< " and group_size=" << group_size
<< ". The supported options are bits in "
      << "{2, 4, 8} and group_size in {32, 64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
    int group_size,
    int bits,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
            group_size,
            bits,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
            group_size,
            bits,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
            group_size,
            bits,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
    int group_size,
    int bits,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
            group_size,
            bits,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
            group_size,
            bits,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
            group_size,
            bits,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
}
} // namespace mlx::core
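
The affine scheme in _qmm/_qmm_t reconstructs each weight as scale * q + bias, with one (scale, bias) pair per group_size consecutive elements. A minimal sketch of unpacking a single 32-bit word (dequantize_word is a hypothetical helper mirroring the inner loop above):

#include <cstdint>

// Unpack the 32 / bits quantized elements stored in one uint32 and
// dequantize them with the group's scale and bias.
template <int bits>
void dequantize_word(uint32_t w, float scale, float bias, float* out) {
  constexpr uint32_t bitmask = (1u << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  for (int p = 0; p < pack_factor; ++p) {
    out[p] = scale * static_cast<float>(w & bitmask) + bias;
    w >>= bits;
  }
}

// e.g. with bits = 4 one word yields 8 weights; with group_size = 64,
// eight consecutive words share a single (scale, bias) pair.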

Some files were not shown because too many files have changed in this diff.