mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 14:34:37 +08:00
Compare commits
119 Commits
checkpoint
...
v0.11.1
Author | SHA1 | Date | |
---|---|---|---|
![]() |
b0012cdd0f | ||
![]() |
84d61d27aa | ||
![]() |
ed83908931 | ||
![]() |
ef5f7d1aea | ||
![]() |
090ff659dc | ||
![]() |
85c8a91a27 | ||
![]() |
581b699ac9 | ||
![]() |
8a0677d56d | ||
![]() |
b18468bf81 | ||
![]() |
107ba2891a | ||
![]() |
cd9e184529 | ||
![]() |
2e7c02d5cd | ||
![]() |
ae18326533 | ||
![]() |
91eba8e485 | ||
![]() |
d07e295c62 | ||
![]() |
dce4bd74a4 | ||
![]() |
ffff671273 | ||
![]() |
12d4507ee3 | ||
![]() |
8580d997ff | ||
![]() |
061cf9a4ce | ||
![]() |
99abb9eff4 | ||
![]() |
fffe072028 | ||
![]() |
a1a31eed27 | ||
![]() |
ae812350f9 | ||
![]() |
b63ef10a7f | ||
![]() |
42afe27e12 | ||
![]() |
76e63212ff | ||
![]() |
aac2f9fb61 | ||
![]() |
bddf23f175 | ||
![]() |
039da779d1 | ||
![]() |
d88d2124b5 | ||
![]() |
e142aaf8a1 | ||
![]() |
0caf35f4b8 | ||
![]() |
3fc993f82d | ||
![]() |
741eb28443 | ||
![]() |
1a87dc5ea8 | ||
![]() |
2427fa171e | ||
![]() |
639e06e1f3 | ||
![]() |
02fedbf1da | ||
![]() |
110d9b149d | ||
![]() |
9cbff5ec1d | ||
![]() |
433c0206b0 | ||
![]() |
8915901966 | ||
![]() |
f48bc496c7 | ||
![]() |
913b19329c | ||
![]() |
d8cb3128f6 | ||
![]() |
5f9ba3019f | ||
![]() |
46caf0bef0 | ||
![]() |
45f636e759 | ||
![]() |
a7b404ff53 | ||
![]() |
c4fd0e5ede | ||
![]() |
bab5386306 | ||
![]() |
aca7584635 | ||
![]() |
d611251502 | ||
![]() |
f30b659291 | ||
![]() |
90dfa43ff1 | ||
![]() |
dc175f08d3 | ||
![]() |
29221fa238 | ||
![]() |
a789685c63 | ||
![]() |
240d10699c | ||
![]() |
925014b661 | ||
![]() |
5611e1a95e | ||
![]() |
570f2bf29e | ||
![]() |
9948eddf11 | ||
![]() |
a3ee03da01 | ||
![]() |
28fcd2b519 | ||
![]() |
8e686764ac | ||
![]() |
479051ce1c | ||
![]() |
bfb5bad4f0 | ||
![]() |
1e16331d9c | ||
![]() |
be98f4ab6b | ||
![]() |
6ee1112f30 | ||
![]() |
8e5a5a1ccd | ||
![]() |
fcda3a0e66 | ||
![]() |
9663c22fe9 | ||
![]() |
f0ae00da12 | ||
![]() |
44390bd3d0 | ||
![]() |
2225374060 | ||
![]() |
105d236889 | ||
![]() |
53e6a9367c | ||
![]() |
f5a1582fe8 | ||
![]() |
a54f06b16f | ||
![]() |
4650d94d98 | ||
![]() |
a5681ebc52 | ||
![]() |
e849b3424a | ||
![]() |
b219d12a6b | ||
![]() |
cec8661113 | ||
![]() |
73a8c090e0 | ||
![]() |
db6796ac61 | ||
![]() |
9a8ee00246 | ||
![]() |
d39ed54f8e | ||
![]() |
16546c70d8 | ||
![]() |
eaba55c9bf | ||
![]() |
19ec023256 | ||
![]() |
63ab0ab580 | ||
![]() |
8dfc376c00 | ||
![]() |
1efee9db09 | ||
![]() |
43abc402d8 | ||
![]() |
3f8b1668c4 | ||
![]() |
76c919b4ec | ||
![]() |
29d0c10ee5 | ||
![]() |
5ad133f8bb | ||
![]() |
d0c544a868 | ||
![]() |
ffb19df3c0 | ||
![]() |
8b7532b9ab | ||
![]() |
366478c560 | ||
![]() |
8e5600022a | ||
![]() |
0e95b64942 | ||
![]() |
0ae22b915b | ||
![]() |
7c441600fe | ||
![]() |
a4d290adb9 | ||
![]() |
28301807c2 | ||
![]() |
74ed0974b3 | ||
![]() |
ec8a4864fa | ||
![]() |
b7588fd5d7 | ||
![]() |
f512b905c7 | ||
![]() |
afd5274049 | ||
![]() |
1074674e32 | ||
![]() |
7762e07fde |
@@ -31,8 +31,7 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install pybind11-stubgen
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
@@ -44,7 +43,8 @@ jobs:
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
python3 setup.py generate_stubs
|
||||
echo "stubs"
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@@ -63,21 +63,24 @@ jobs:
|
||||
command: ./build/tests/tests
|
||||
|
||||
mac_build_and_test:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.large.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.9
|
||||
python3.9 -m venv env
|
||||
brew install python@3.8
|
||||
python3.8 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install pybind11-stubgen
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@@ -91,13 +94,13 @@ jobs:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
@@ -140,9 +143,8 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install --upgrade setuptools
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install twine
|
||||
pip install build
|
||||
@@ -157,7 +159,7 @@ jobs:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
@@ -205,9 +207,8 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
pip install --upgrade setuptools
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
@@ -215,7 +216,7 @@ jobs:
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
pip install . -v
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
python -m build --wheel
|
||||
@@ -235,7 +236,10 @@ workflows:
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- mac_build_and_test
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
- linux_build_and_test
|
||||
|
||||
build_pypi_release:
|
||||
@@ -254,7 +258,7 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["14.3.1", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
prb:
|
||||
when:
|
||||
@@ -268,6 +272,9 @@ workflows:
|
||||
context: pr-approval
|
||||
- mac_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -280,7 +287,7 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["14.3.1", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -291,7 +298,7 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
xcode_version: ["14.3.1", "15.2.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
linux_test_release:
|
||||
when:
|
||||
|
@@ -1,11 +1,11 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v17.0.6
|
||||
rev: v18.1.3
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.2.0
|
||||
rev: 24.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
|
@@ -15,6 +15,8 @@ MLX was developed with contributions from the following individuals:
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
@@ -15,31 +15,33 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.5.1)
|
||||
set(MLX_VERSION 0.11.1)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
|
||||
message(FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
message(WARNING
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, "
|
||||
" make sure you are building for arm64.")
|
||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
if(NOT MLX_ENABLE_X64_MAC)
|
||||
message(FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||
else()
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
@@ -64,8 +66,14 @@ endif()
|
||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif (MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
|
||||
if (MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
@@ -74,18 +82,19 @@ elseif (MLX_BUILD_METAL)
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
||||
else()
|
||||
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
||||
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
metal_cpp
|
||||
URL ${METAL_CPP_URL}
|
||||
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
@@ -110,7 +119,27 @@ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
#set(BLA_VENDOR Generic)
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
set(BLA_VENDOR OpenBLAS)
|
||||
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||
endif()
|
||||
# Search and link with lapack.
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
/usr/local/opt/openblas/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally include an old version
|
||||
# of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
if (NOT BLAS_FOUND)
|
||||
message(FATAL_ERROR "Must have BLAS installed")
|
||||
@@ -124,17 +153,6 @@ else()
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
@@ -148,8 +166,12 @@ target_include_directories(
|
||||
|
||||
if (MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(Python COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||
endif()
|
||||
|
||||
|
@@ -17,14 +17,13 @@
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " \
|
||||
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||
<< std::endl;
|
||||
#define TIMEM(MSG, FUNC, ...) \
|
||||
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
|
||||
<< std::flush << std::setprecision(5) \
|
||||
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
|
||||
|
||||
template <typename F, typename... Args>
|
||||
double time_fn(F fn, Args... args) {
|
||||
double time_fn(F fn, Args&&... args) {
|
||||
// warmup
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
eval(fn(std::forward<Args>(args)...));
|
||||
|
57
benchmarks/python/fft_bench.py
Normal file
57
benchmarks/python/fft_bench.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def bandwidth_gb(runtime_ms, system_size):
|
||||
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
|
||||
bytes_per_gb = 1e9
|
||||
ms_per_s = 1e3
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||
|
||||
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
|
||||
return bandwidths
|
||||
|
||||
|
||||
def time_fft():
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_fft()
|
41
benchmarks/python/layer_norm_bench.py
Normal file
41
benchmarks/python/layer_norm_bench.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def layer_norm(x, w, b, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
mu = mx.mean(x, -1, keepdims=True)
|
||||
v = mx.var(x, -1, keepdims=True)
|
||||
return (x - mu) * mx.rsqrt(v + eps) * w + b
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
39
benchmarks/python/rms_norm_bench.py
Normal file
39
benchmarks/python/rms_norm_bench.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def rms_norm(x, w, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||
return (x * n).astype(ot) * w
|
||||
|
||||
|
||||
def time_rms_norm():
|
||||
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1))
|
||||
g2 = mx.grad(f2, argnums=(0, 1))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, y)
|
||||
|
||||
def rms_norm_loop(g, x, w):
|
||||
gx, gw = x, w
|
||||
for _ in range(32):
|
||||
gx, gw = g(gx, gw, y)
|
||||
return gx, gw
|
||||
|
||||
time_fn(rms_norm_loop, g1, x, w)
|
||||
time_fn(rms_norm_loop, g2, x, w)
|
||||
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
||||
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_rms_norm()
|
@@ -6,21 +6,21 @@ from time_utils import time_fn
|
||||
|
||||
|
||||
def time_rope():
|
||||
rope = nn.RoPE(4096)
|
||||
rope = nn.RoPE(64)
|
||||
|
||||
# vec
|
||||
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
|
||||
mx.eval(x)
|
||||
|
||||
def rope_vec(x):
|
||||
for _ in range(32):
|
||||
x = rope(x)
|
||||
x = rope(x, offset=100)
|
||||
return x
|
||||
|
||||
time_fn(rope_vec, x)
|
||||
|
||||
# matrix
|
||||
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
|
||||
mx.eval(x)
|
||||
|
||||
def rope_mat(x):
|
||||
|
36
cmake/metal.14.0.diff
Normal file
36
cmake/metal.14.0.diff
Normal file
@@ -0,0 +1,36 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
|
||||
@@ -1906,6 +1906,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
36
cmake/metal.14.2.diff
Normal file
36
cmake/metal.14.2.diff
Normal file
@@ -0,0 +1,36 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
|
||||
@@ -1918,6 +1918,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 746 KiB |
20
docs/src/_templates/nn-module-template.rst
Normal file
20
docs/src/_templates/nn-module-template.rst
Normal file
@@ -0,0 +1,20 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
|
||||
{% block methods %}
|
||||
|
||||
{% if methods %}
|
||||
.. rubric:: {{ _('Methods') }}
|
||||
|
||||
.. autosummary::
|
||||
{% for item in methods %}
|
||||
{%- if item not in inherited_members and item != "__init__" %}
|
||||
~{{ name }}.{{ item }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
@@ -29,16 +29,17 @@ autosummary_generate = True
|
||||
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
|
||||
|
||||
intersphinx_mapping = {
|
||||
"https://docs.python.org/3": None,
|
||||
"https://numpy.org/doc/stable/": None,
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"numpy": ("https://numpy.org/doc/stable/", None),
|
||||
}
|
||||
|
||||
templates_path = ["_templates"]
|
||||
html_static_path = ["_static"]
|
||||
source_suffix = ".rst"
|
||||
master_doc = "index"
|
||||
main_doc = "index"
|
||||
highlight_language = "python"
|
||||
pygments_style = "sphinx"
|
||||
add_module_names = False
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
@@ -59,3 +60,22 @@ html_theme_options = {
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
htmlhelp_basename = "mlx_doc"
|
||||
|
||||
|
||||
def setup(app):
|
||||
from sphinx.util import inspect
|
||||
|
||||
wrapped_isfunc = inspect.isfunction
|
||||
|
||||
def isfunc(obj):
|
||||
type_name = str(type(obj))
|
||||
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
|
||||
return True
|
||||
return wrapped_isfunc(obj)
|
||||
|
||||
inspect.isfunction = isfunc
|
||||
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
|
||||
|
@@ -1,24 +1,16 @@
|
||||
Developer Documentation
|
||||
=======================
|
||||
|
||||
MLX provides an open and flexible backend to which users may add operations
|
||||
and specialized implementations without much hassle. While the library supplies
|
||||
efficient operations that can be used and composed for any number of
|
||||
applications, there may arise cases where new functionalities or highly
|
||||
optimized implementations are needed. For such cases, you may design and
|
||||
implement your own operations that link to and build on top of :mod:`mlx.core`.
|
||||
We will introduce the inner-workings of MLX and go over a simple example to
|
||||
learn the steps involved in adding new operations to MLX with your own CPU
|
||||
and GPU implementations.
|
||||
You can extend MLX with custom operations on the CPU or GPU. This guide
|
||||
explains how to do that with a simple example.
|
||||
|
||||
Introducing the Example
|
||||
-----------------------
|
||||
|
||||
Let's say that you would like an operation that takes in two arrays,
|
||||
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
|
||||
respectively, and then adds them together to get the result
|
||||
``z = alpha * x + beta * y``. Well, you can very easily do that by just
|
||||
writing out a function as follows:
|
||||
Let's say you would like an operation that takes in two arrays, ``x`` and
|
||||
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
|
||||
and then adds them together to get the result ``z = alpha * x + beta * y``.
|
||||
You can do that in MLX directly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -27,44 +19,35 @@ writing out a function as follows:
|
||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||
return alpha * x + beta * y
|
||||
|
||||
This function performs that operation while leaving the implementations and
|
||||
differentiation to MLX.
|
||||
This function performs that operation while leaving the implementation and
|
||||
function transformations to MLX.
|
||||
|
||||
However, you work with vector math libraries often and realize that the
|
||||
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
|
||||
You would really like the part of your applications that does this operation
|
||||
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
|
||||
our assumptions on to you, let's also assume that you want to learn how to add
|
||||
your own implementation for the gradients of your new operation while going
|
||||
over the ins-and-outs of the MLX framework.
|
||||
However you may need to customize the underlying implementation, perhaps to
|
||||
make it faster or for custom differentiation. In this tutorial we will go
|
||||
through adding custom extensions. It will cover:
|
||||
|
||||
Well, what a coincidence! You are in the right place. Over the course of this
|
||||
example, we will learn:
|
||||
|
||||
* The structure of the MLX library from the frontend API to the backend implementations.
|
||||
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
|
||||
* How to implement your own GPU implementation using metal.
|
||||
* How to add your own ``vjp`` and ``jvp``.
|
||||
* How to build your implementations, link them to MLX, and bind them to python.
|
||||
* The structure of the MLX library.
|
||||
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
|
||||
* Implementing a GPU operation using Metal.
|
||||
* Adding the ``vjp`` and ``jvp`` function transformation.
|
||||
* Building a custom extension and binding it to python.
|
||||
|
||||
Operations and Primitives
|
||||
-------------------------
|
||||
|
||||
In one sentence, operations in MLX build the computation graph, and primitives
|
||||
provide the rules for evaluation and transformations of said graph. Let's start
|
||||
by discussing operations in more detail.
|
||||
Operations in MLX build the computation graph. Primitives provide the rules for
|
||||
evaluating and transforming the graph. Let's start by discussing operations in
|
||||
more detail.
|
||||
|
||||
Operations
|
||||
^^^^^^^^^^^
|
||||
|
||||
Operations are the frontend functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
|
||||
operations in the Python API (:ref:`ops`).
|
||||
Operations are the front-end functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
||||
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
|
||||
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
|
||||
C++ API:
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
|
||||
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
||||
C++:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -83,10 +66,7 @@ C++ API:
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
|
||||
This operation itself can call other operations within it if needed. So, the
|
||||
simplest way to go about implementing this operation would be to do so in terms
|
||||
of existing operations.
|
||||
The simplest way to implement this operation is in terms of existing operations:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -100,25 +80,23 @@ of existing operations.
|
||||
// Scale x and y on the provided stream
|
||||
auto ax = multiply(array(alpha), x, s);
|
||||
auto by = multiply(array(beta), y, s);
|
||||
|
||||
|
||||
// Add and return
|
||||
return add(ax, by, s);
|
||||
}
|
||||
|
||||
However, as we discussed earlier, this is not our goal. The operations themselves
|
||||
do not contain the implementations that act on the data, nor do they contain the
|
||||
rules of transformations. Rather, they are an easy-to-use interface that builds
|
||||
on top of the building blocks we call :class:`Primitive`.
|
||||
The operations themselves do not contain the implementations that act on the
|
||||
data, nor do they contain the rules of transformations. Rather, they are an
|
||||
easy-to-use interface that uses :class:`Primitive` building blocks.
|
||||
|
||||
Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create an output given a set of input :class:`array`. Further,
|
||||
a :class:`Primitive` is a class that contains rules on how it is evaluated
|
||||
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
|
||||
``jvp``. These words on their own can be a bit abstract, so let's take a step
|
||||
back and go to our example to give ourselves a more concrete image.
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array jvp(
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) override;
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<array, int> vmap(
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class and
|
||||
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
|
||||
``beta`` as parameters. It then provides implementations of how the array ``out``
|
||||
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
|
||||
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
|
||||
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
||||
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
|
||||
implementations of how the output array is produced given the inputs through
|
||||
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
|
||||
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
|
||||
:meth:`Axpby::vmap`.
|
||||
|
||||
Using the Primitives
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
Using the Primitive
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to
|
||||
the computation graph. An :class:`array` can be constructed by providing its
|
||||
data type, shape, the :class:`Primitive` that computes it, and the
|
||||
:class:`array` inputs that are passed to the primitive.
|
||||
Operations can use this :class:`Primitive` to add a new :class:`array` to the
|
||||
computation graph. An :class:`array` can be constructed by providing its data
|
||||
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
|
||||
inputs that are passed to the primitive.
|
||||
|
||||
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -223,7 +206,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
@@ -238,27 +221,26 @@ This operation now handles the following:
|
||||
Implementing the Primitive
|
||||
--------------------------
|
||||
|
||||
No computation happens when we call the operation alone. In effect, the
|
||||
operation only builds the computation graph. When we evaluate the output
|
||||
array, MLX schedules the execution of the computation graph, and calls
|
||||
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
|
||||
stream/device specified by the user.
|
||||
No computation happens when we call the operation alone. The operation only
|
||||
builds the computation graph. When we evaluate the output array, MLX schedules
|
||||
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
|
||||
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
|
||||
|
||||
.. warning::
|
||||
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
|
||||
no memory has been allocated for the output array. It falls on the implementation
|
||||
of these functions to allocate memory as needed
|
||||
of these functions to allocate memory as needed.
|
||||
|
||||
Implementing the CPU Backend
|
||||
Implementing the CPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's start by trying to implement a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
Let's start by implementing a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
Our naive method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -296,19 +278,19 @@ pointwise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
}
|
||||
}
|
||||
|
||||
Now, we would like our implementation to be able to do this pointwise operation
|
||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
||||
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
|
||||
if we encounter an unexpected type.
|
||||
Our implementation should work for all incoming floating point arrays.
|
||||
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
``complex64``. We throw an error if we encounter an unexpected type.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
"[Axpby] Only supports floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
We have a fallback implementation! Now, to do what we are really here to do.
|
||||
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
|
||||
framework? Well, there are 3 complications to keep in mind:
|
||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
||||
provided by the Accelerate_ framework for a faster implementation in certain
|
||||
cases:
|
||||
|
||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||
floats. We can only direct to it for ``float32`` types
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
|
||||
have fixed strides between them. Possibly due to broadcasts and transposes,
|
||||
we aren't guaranteed that the inputs fit this requirement. We can
|
||||
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
|
||||
column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
|
||||
MLX expects to write out the answer to a new array. We must copy the elements
|
||||
of ``y`` into the output array and use that as an input to ``axpby``
|
||||
floats. We can only use it for ``float32`` types.
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
||||
elements have fixed strides between them. We only direct to Accelerate
|
||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
||||
MLX expects to write the output to a new array. We must copy the elements
|
||||
of ``y`` into the output and use that as an input to ``axpby``.
|
||||
|
||||
Let's write out an implementation that uses Accelerate in the right conditions.
|
||||
It must simply allocate data for the output, copy elements of ``y`` into it,
|
||||
and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
Let's write an implementation that uses Accelerate in the right conditions.
|
||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
||||
:func:`catlas_saxpby` from Accelerate.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
|
||||
// This specialization requires both x and y be contiguous in the same mode
|
||||
// i.e: corresponding linear indices in both point to corresponding elements
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
Great! But what about the inputs that do not fit the criteria for accelerate?
|
||||
Luckily, we can always just direct back to :meth:`Axpby::eval`.
|
||||
|
||||
With this in mind, let's finally implement our :meth:`Axpby::eval_cpu`.
|
||||
For inputs that do not fit the criteria for Accelerate, we fall back to
|
||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
||||
:meth:`Axpby::eval_cpu`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
@@ -410,35 +382,33 @@ With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, out);
|
||||
// Fall back to common back-end if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
We have now hit a milestone! Just this much is enough to run the operation
|
||||
:meth:`axpby` on a CPU stream!
|
||||
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||
you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
If you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
|
||||
Implementing the GPU Backend
|
||||
Implementing the GPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
all GPU kernels in MLX are written using metal.
|
||||
Apple silicon devices address their GPUs using the Metal_ shading language, and
|
||||
GPU kernels in MLX are written using Metal.
|
||||
|
||||
.. note::
|
||||
|
||||
Here are some helpful resources if you are new to metal!
|
||||
Here are some helpful resources if you are new to Metal:
|
||||
|
||||
* A walkthrough of the metal compute pipeline: `Metal Example`_
|
||||
* Documentation for metal shading language: `Metal Specification`_
|
||||
* Using metal from C++: `Metal-cpp`_
|
||||
|
||||
Let's keep the GPU algorithm simple. We will launch exactly as many threads
|
||||
as there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
|
||||
element in the output.
|
||||
Let's keep the GPU kernel simple. We will launch exactly as many threads as
|
||||
there are elements in the output. Each thread will pick the element it needs
|
||||
from ``x`` and ``y``, do the point-wise operation, and update its assigned
|
||||
element in the output.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -457,15 +427,14 @@ element in the output.
|
||||
// Convert linear indices to offsets in array
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||
|
||||
|
||||
// Do the operation and update the output
|
||||
out[index] =
|
||||
out[index] =
|
||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||
}
|
||||
|
||||
We then need to instantiate this template for all floating point types and give
|
||||
each instantiation a unique host name so we can identify the right kernel for
|
||||
each data type.
|
||||
each instantiation a unique host name so we can identify it.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -488,29 +457,21 @@ each data type.
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
|
||||
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
|
||||
will see later in :ref:`Building with CMake`. In the following example, we
|
||||
assume that the library ``mlx_ext.metallib`` will always be co-located with
|
||||
the executable/ shared-library calling the :meth:`register_library` function.
|
||||
The :meth:`register_library` function takes the library's name and potential
|
||||
path (or in this case, a function that can produce the path of the metal
|
||||
library) and tries to load that library if it hasn't already been registered
|
||||
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
|
||||
it is important to package your C++ library with the metal library. We will
|
||||
go over this process in more detail later.
|
||||
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions
|
||||
and dispatch it to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
|
||||
below.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
@@ -518,10 +479,10 @@ below.
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
@@ -552,7 +513,7 @@ below.
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
@@ -575,28 +536,25 @@ below.
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
|
||||
A few things to note about MLX and metal before moving on. MLX keeps track
|
||||
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
|
||||
to give us the active metal compute command encoder instead of building a
|
||||
new one and calling :meth:`compute_encoder->end_encoding` at the end.
|
||||
MLX keeps adding kernels (compute pipelines) to the active command encoder
|
||||
until some specified limit is hit or the compute encoder needs to be flushed
|
||||
for synchronization. MLX also handles enqueuing and committing the associated
|
||||
command buffers as needed. We suggest taking a deeper dive into
|
||||
:class:`metal::Device` if you would like to study this routine further.
|
||||
A few things to note about MLX and Metal before moving on. MLX keeps track of
|
||||
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
|
||||
associated. We rely on :meth:`d.get_command_encoder` to give us the active
|
||||
metal compute command encoder instead of building a new one and calling
|
||||
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
|
||||
pipelines) to the active command buffer until some specified limit is hit or
|
||||
the command buffer needs to be flushed for synchronization.
|
||||
|
||||
Primitive Transforms
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Now that we have come this far, let's also learn how to add implementations to
|
||||
transformations in a :class:`Primitive`. These transformations can be built on
|
||||
top of our operations, including the one we just defined now. Which then gives
|
||||
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
Next, let's add implementations for transformations in a :class:`Primitive`.
|
||||
These transformations can be built on top of other operations, including the
|
||||
one we just defined:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
array Axpby::jvp(
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
@@ -611,12 +569,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return multiply(scale_arr, tangents[0], stream());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -625,34 +583,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<int>& /* unused */) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotan.dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
Finally, you need not have a transformation fully defined to start using your
|
||||
own :class:`Primitive`.
|
||||
Note, a transformation does not need to be fully defined to start using
|
||||
the :class:`Primitive`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<array, int> Axpby::vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
throw std::runtime_error("[Axpby] vmap not implemented.");
|
||||
}
|
||||
|
||||
Building and Binding
|
||||
--------------------
|
||||
|
||||
Let's look at the overall directory structure first.
|
||||
Let's look at the overall directory structure first.
|
||||
|
||||
| extensions
|
||||
| ├── axpby
|
||||
@@ -666,40 +625,39 @@ Let's look at the overall directory structure first.
|
||||
| └── setup.py
|
||||
|
||||
* ``extensions/axpby/`` defines the C++ extension library
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated python package
|
||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
python bindings
|
||||
* ``extensions/mlx_sample_extensions`` sets out the structure for the
|
||||
associated Python package
|
||||
* ``extensions/bindings.cpp`` provides Python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
Python bindings
|
||||
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
|
||||
the python package
|
||||
the Python package
|
||||
|
||||
Binding to Python
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
|
||||
We use nanobind_ to build a Python API for the C++ library. Since bindings for
|
||||
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
|
||||
already provided, adding our :meth:`axpby` is simple!
|
||||
already provided, adding our :meth:`axpby` is simple.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -711,17 +669,17 @@ already provided, adding our :meth:`axpby` is simple!
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
)");
|
||||
}
|
||||
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
whistles such as the literal names and doc-strings.
|
||||
|
||||
.. warning::
|
||||
|
||||
:mod:`mlx.core` needs to be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:mod:`mlx.core` must be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:class:`mlx.core.array` are available.
|
||||
|
||||
.. _Building with CMake:
|
||||
@@ -729,8 +687,8 @@ whistles such as the literal names and doc-strings.
|
||||
Building with CMake
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Building the C++ extension library itself is simple, it only requires that you
|
||||
``find_package(MLX CONFIG)`` and then link it to your library.
|
||||
Building the C++ extension library only requires that you ``find_package(MLX
|
||||
CONFIG)`` and then link it to your library.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
@@ -752,12 +710,12 @@ Building the C++ extension library itself is simple, it only requires that you
|
||||
# Link to mlx
|
||||
target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
We also need to build the attached metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with MLX package).
|
||||
We also need to build the attached Metal library. For convenience, we provide a
|
||||
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
|
||||
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
|
||||
automatically imported with MLX package).
|
||||
|
||||
Here is what that looks like in practice!
|
||||
Here is what that looks like in practice:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
@@ -779,27 +737,29 @@ Here is what that looks like in practice!
|
||||
|
||||
endif()
|
||||
|
||||
Finally, we build the Pybind11_ bindings
|
||||
Finally, we build the nanobind_ bindings
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
||||
Building with ``setuptools``
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Once we have set out the CMake build rules as described above, we can use the
|
||||
build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
build utilities defined in :mod:`mlx.extension`:
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
from mlx import extension
|
||||
from setuptools import setup
|
||||
@@ -809,48 +769,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages = ["mlx_sample_extensions"],
|
||||
package_dir = {"": "mlx_sample_extensions"},
|
||||
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev":[]},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.7",
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
||||
.. note::
|
||||
We treat ``extensions/mlx_sample_extensions`` as the package directory
|
||||
even though it only contains a ``__init__.py`` to ensure the following:
|
||||
|
||||
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
You can build inplace for development using
|
||||
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
|
||||
* The C++ extension library and the metal library are co-located with the python
|
||||
bindings and copied together if the package is installed
|
||||
|
||||
To build the package, first install the build dependencies with ``pip install
|
||||
-r requirements.txt``. You can then build inplace for development using
|
||||
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
|
||||
|
||||
This will result in a directory structure as follows:
|
||||
This results in the directory structure:
|
||||
|
||||
| extensions
|
||||
| ├── mlx_sample_extensions
|
||||
| │ ├── __init__.py
|
||||
| │ ├── libmlx_ext.dylib # C++ extension library
|
||||
| │ ├── mlx_ext.metallib # Metal library
|
||||
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
|
||||
| │ └── _ext.cpython-3x-darwin.so # Python Binding
|
||||
| ...
|
||||
|
||||
When you try to install using the command ``python -m pip install .``
|
||||
(in ``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
|
||||
copied along with the python binding since they are specified as ``package_data``.
|
||||
When you try to install using the command ``python -m pip install .`` (in
|
||||
``extensions/``), the package will be installed with the same structure as
|
||||
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
|
||||
copied along with the Python binding since they are specified as
|
||||
``package_data``.
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the python package and play with it as you would any other MLX operation!
|
||||
After installing the extension as described above, you should be able to simply
|
||||
import the Python package and play with it as you would any other MLX operation.
|
||||
|
||||
Let's looks at a simple script and it's results!
|
||||
Let's look at a simple script and its results:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -874,12 +836,12 @@ Output:
|
||||
c correctness: True
|
||||
|
||||
Results
|
||||
^^^^^^^^^^^^^^^^
|
||||
^^^^^^^
|
||||
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we first defined on the CPU.
|
||||
|
||||
.. code-block:: python
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_sample_extensions import axpby
|
||||
@@ -898,7 +860,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
alpha = 4.0
|
||||
beta = 2.0
|
||||
|
||||
mx.eval((x, y))
|
||||
mx.eval(x, y)
|
||||
|
||||
def bench(f):
|
||||
# Warm up
|
||||
@@ -919,30 +881,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.
|
||||
|
||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
||||
|
||||
Results:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Simple axpby: 0.114 s | Custom axpby: 0.109 s
|
||||
|
||||
We see some modest improvements right away!
|
||||
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
|
||||
modest improvements right away!
|
||||
|
||||
This operation is now good to be used to build other operations, in
|
||||
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
|
||||
:meth:`grad`!
|
||||
:meth:`grad`.
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx <code>`_.
|
||||
|
||||
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
|
||||
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
|
||||
|
||||
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
|
||||
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
|
||||
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
|
||||
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
|
||||
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/
|
||||
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
|
||||
|
69
docs/src/dev/metal_debugger.rst
Normal file
69
docs/src/dev/metal_debugger.rst
Normal file
@@ -0,0 +1,69 @@
|
||||
Metal Debugger
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Profiling is a key step for performance optimization. You can build MLX with
|
||||
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
|
||||
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||
|
||||
* Records source during Metal compilation, for later inspection while
|
||||
debugging.
|
||||
* Labels Metal objects such as command queues, improving capture readability.
|
||||
|
||||
To build with debugging enabled in Python prepend
|
||||
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
|
||||
|
||||
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
|
||||
work.
|
||||
|
||||
.. note::
|
||||
|
||||
To capture a GPU trace you must run the application with
|
||||
``MTL_CAPTURE_ENABLED=1``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
a = mx.random.uniform(shape=(512, 512))
|
||||
b = mx.random.uniform(shape=(512, 512))
|
||||
mx.eval(a, b)
|
||||
|
||||
trace_file = "mlx_trace.gputrace"
|
||||
|
||||
if not mx.metal.start_capture(trace_file):
|
||||
print("Make sure to run with MTL_CAPTURE_ENABLED=1 and "
|
||||
f"that the path {trace_file} does not already exist.")
|
||||
exit(1)
|
||||
|
||||
for _ in range(10):
|
||||
mx.eval(mx.add(a, b))
|
||||
|
||||
mx.metal.stop_capture()
|
||||
|
||||
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
|
||||
has a great overview of all operations. Check out the `Metal debugger
|
||||
documentation`_ for more information.
|
||||
|
||||
.. image:: ../_static/metal_debugger/capture.png
|
||||
:class: dark-light
|
||||
|
||||
Xcode Workflow
|
||||
--------------
|
||||
|
||||
You can skip saving to a path by running within Xcode. First, generate an
|
||||
Xcode project using CMake.
|
||||
|
||||
.. code-block::
|
||||
|
||||
mkdir build && cd build
|
||||
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
|
||||
open mlx.xcodeproj
|
||||
|
||||
Select the ``metal_capture`` example scheme and run.
|
||||
|
||||
.. image:: ../_static/metal_debugger/schema.png
|
||||
:class: dark-light
|
||||
|
||||
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger
|
@@ -58,10 +58,12 @@ are the CPU and GPU.
|
||||
:maxdepth: 1
|
||||
|
||||
python/array
|
||||
python/data_types
|
||||
python/devices_and_streams
|
||||
python/ops
|
||||
python/random
|
||||
python/transforms
|
||||
python/fast
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
@@ -80,3 +82,4 @@ are the CPU and GPU.
|
||||
:maxdepth: 1
|
||||
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
|
@@ -15,10 +15,10 @@ To install from PyPI you must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.8
|
||||
- macOS >= 13.3
|
||||
- macOS >= 13.5
|
||||
|
||||
.. note::
|
||||
MLX is only available on devices running macOS >= 13.3
|
||||
MLX is only available on devices running macOS >= 13.5
|
||||
It is highly recommended to use macOS 14 (Sonoma)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ Build Requirements
|
||||
|
||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
|
||||
- Xcode >= 15.0 and macOS SDK >= 14.0
|
||||
|
||||
.. note::
|
||||
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
|
||||
@@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from
|
||||
|
||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||
|
||||
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
|
||||
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
|
||||
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "pybind11[global]"
|
||||
conda install pybind11
|
||||
brew install pybind11
|
||||
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
|
||||
|
||||
Then simply build and install it using pip:
|
||||
Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
@@ -123,7 +120,7 @@ Create a build directory and run CMake and make:
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. && make -j
|
||||
cmake .. && make -j
|
||||
|
||||
Run tests with:
|
||||
|
||||
@@ -142,7 +139,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
|
||||
preprocessor constant ``METAL_PATH`` should be defined at build time and it
|
||||
should point to the path to the built metal library.
|
||||
|
||||
.. list-table:: Build Options
|
||||
.. list-table:: Build Options
|
||||
:widths: 25 8
|
||||
:header-rows: 1
|
||||
|
||||
@@ -158,19 +155,21 @@ should point to the path to the built metal library.
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
||||
* - MLX_METAL_DEBUG
|
||||
- OFF
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
If you have multiple Xcode installations and wish to use
|
||||
a specific one while building, you can do so by adding the
|
||||
following environment variable before building
|
||||
If you have multiple Xcode installations and wish to use
|
||||
a specific one while building, you can do so by adding the
|
||||
following environment variable before building
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
|
||||
|
||||
Further, you can use the following command to find out which
|
||||
Further, you can use the following command to find out which
|
||||
macOS SDK will be used
|
||||
|
||||
.. code-block:: shell
|
||||
@@ -202,7 +201,7 @@ Then set the active developer directory:
|
||||
|
||||
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
|
||||
|
||||
x86 Shell
|
||||
x86 Shell
|
||||
~~~~~~~~~
|
||||
|
||||
.. _build shell:
|
||||
|
@@ -10,27 +10,38 @@ Array
|
||||
|
||||
array
|
||||
array.astype
|
||||
array.at
|
||||
array.item
|
||||
array.tolist
|
||||
array.dtype
|
||||
array.itemsize
|
||||
array.nbytes
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
Dtype
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
array.argmax
|
||||
array.argmin
|
||||
array.cos
|
||||
array.dtype
|
||||
array.cummax
|
||||
array.cummin
|
||||
array.cumprod
|
||||
array.cumsum
|
||||
array.diag
|
||||
array.diagonal
|
||||
array.exp
|
||||
array.flatten
|
||||
array.log
|
||||
array.log10
|
||||
array.log1p
|
||||
array.log2
|
||||
array.logsumexp
|
||||
array.max
|
||||
array.mean
|
||||
array.min
|
||||
array.moveaxis
|
||||
array.prod
|
||||
array.reciprocal
|
||||
array.reshape
|
||||
@@ -40,6 +51,8 @@ Array
|
||||
array.split
|
||||
array.sqrt
|
||||
array.square
|
||||
array.squeeze
|
||||
array.swapaxes
|
||||
array.sum
|
||||
array.transpose
|
||||
array.T
|
||||
|
@@ -1,7 +1,5 @@
|
||||
.. _data_types:
|
||||
|
||||
:orphan:
|
||||
|
||||
Data Types
|
||||
==========
|
||||
|
||||
@@ -44,9 +42,27 @@ The default floating point type is ``float32`` and the default integer type is
|
||||
* - ``int64``
|
||||
- 8
|
||||
- 64-bit signed integer
|
||||
* - ``bfloat16``
|
||||
- 2
|
||||
- 16-bit brain float (e8, m7)
|
||||
* - ``float16``
|
||||
- 2
|
||||
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
|
||||
- 16-bit IEEE float (e5, m10)
|
||||
* - ``float32``
|
||||
- 4
|
||||
- 32-bit float
|
||||
* - ``complex64``
|
||||
- 8
|
||||
- 64-bit complex float
|
||||
|
||||
|
||||
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||
``dtype`` (or category) is a subtype of another category.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Dtype
|
||||
DtypeCategory
|
||||
issubdtype
|
||||
|
14
docs/src/python/fast.rst
Normal file
14
docs/src/python/fast.rst
Normal file
@@ -0,0 +1,14 @@
|
||||
.. _fast:
|
||||
|
||||
Fast
|
||||
====
|
||||
|
||||
.. currentmodule:: mlx.core.fast
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
rms_norm
|
||||
layer_norm
|
||||
rope
|
||||
scaled_dot_product_attention
|
@@ -3,7 +3,7 @@ Metal
|
||||
|
||||
.. currentmodule:: mlx.core.metal
|
||||
|
||||
.. autosummary::
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
@@ -12,3 +12,5 @@ Metal
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -173,7 +173,7 @@ In detail:
|
||||
:toctree: _autosummary
|
||||
|
||||
value_and_grad
|
||||
checkpoint
|
||||
quantize
|
||||
|
||||
.. toctree::
|
||||
|
||||
|
@@ -21,17 +21,21 @@ Layers
|
||||
Embedding
|
||||
GELU
|
||||
GroupNorm
|
||||
GRU
|
||||
InstanceNorm
|
||||
LayerNorm
|
||||
Linear
|
||||
LSTM
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
Mish
|
||||
MultiHeadAttention
|
||||
PReLU
|
||||
QuantizedEmbedding
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
RNN
|
||||
RoPE
|
||||
SELU
|
||||
Sequential
|
||||
@@ -40,4 +44,4 @@ Layers
|
||||
Softshrink
|
||||
Step
|
||||
Transformer
|
||||
Upsample
|
||||
Upsample
|
||||
|
@@ -30,6 +30,7 @@ Module
|
||||
Module.named_modules
|
||||
Module.parameters
|
||||
Module.save_weights
|
||||
Module.set_dtype
|
||||
Module.train
|
||||
Module.trainable_parameters
|
||||
Module.unfreeze
|
||||
|
@@ -5,13 +5,13 @@ Operations
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
abs
|
||||
add
|
||||
all
|
||||
allclose
|
||||
allclose
|
||||
any
|
||||
arange
|
||||
arccos
|
||||
@@ -29,6 +29,7 @@ Operations
|
||||
atleast_2d
|
||||
atleast_3d
|
||||
broadcast_to
|
||||
block_masked_mm
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
@@ -38,6 +39,10 @@ Operations
|
||||
conv_general
|
||||
cos
|
||||
cosh
|
||||
cummax
|
||||
cummin
|
||||
cumprod
|
||||
cumsum
|
||||
dequantize
|
||||
diag
|
||||
diagonal
|
||||
@@ -47,6 +52,7 @@ Operations
|
||||
erf
|
||||
erfinv
|
||||
exp
|
||||
expm1
|
||||
expand_dims
|
||||
eye
|
||||
flatten
|
||||
@@ -58,10 +64,10 @@ Operations
|
||||
identity
|
||||
inner
|
||||
isclose
|
||||
isnan
|
||||
isposinf
|
||||
isneginf
|
||||
isinf
|
||||
isnan
|
||||
isneginf
|
||||
isposinf
|
||||
less
|
||||
less_equal
|
||||
linspace
|
||||
@@ -79,6 +85,7 @@ Operations
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
moveaxis
|
||||
@@ -113,6 +120,7 @@ Operations
|
||||
square
|
||||
squeeze
|
||||
stack
|
||||
std
|
||||
stop_gradient
|
||||
subtract
|
||||
sum
|
||||
|
@@ -38,6 +38,7 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
|
||||
gumbel
|
||||
key
|
||||
normal
|
||||
multivariate_normal
|
||||
randint
|
||||
seed
|
||||
split
|
||||
|
@@ -17,4 +17,3 @@ Transforms
|
||||
jvp
|
||||
vjp
|
||||
vmap
|
||||
checkpoint
|
||||
|
@@ -19,3 +19,4 @@ return python trees will be using the default python ``dict``, ``list`` and
|
||||
tree_flatten
|
||||
tree_unflatten
|
||||
tree_map
|
||||
tree_map_with_path
|
||||
|
@@ -40,7 +40,7 @@ getting higher order derivatives.
|
||||
|
||||
Any of the MLX function transformations can be composed in any order to any
|
||||
depth. See the following sections for more information on :ref:`automatic
|
||||
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
|
||||
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
|
||||
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
|
||||
|
||||
|
||||
|
@@ -18,7 +18,7 @@ describe below.
|
||||
Transforming Compute Graphs
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Lazy evaluation let's us record a compute graph without actually doing any
|
||||
Lazy evaluation lets us record a compute graph without actually doing any
|
||||
computations. This is useful for function transformations like :func:`grad` and
|
||||
:func:`vmap` and graph optimizations.
|
||||
|
||||
|
@@ -49,7 +49,7 @@ it will be added. You can load the array with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> mx.load("array.npy", a)
|
||||
>>> mx.load("array.npy")
|
||||
array([1], dtype=float32)
|
||||
|
||||
Here's an example of saving several arrays to a single file:
|
||||
|
@@ -8,3 +8,4 @@ endfunction(build_example)
|
||||
build_example(tutorial.cpp)
|
||||
build_example(linear_regression.cpp)
|
||||
build_example(logistic_regression.cpp)
|
||||
build_example(metal_capture.cpp)
|
||||
|
31
examples/cpp/metal_capture.cpp
Normal file
31
examples/cpp/metal_capture.cpp
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
assert(metal::start_capture("mlx_trace.gputrace"));
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
|
||||
metal::stop_capture();
|
||||
}
|
@@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(mlx_sample_extensions LANGUAGES CXX)
|
||||
project(_ext LANGUAGES CXX)
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -11,8 +11,12 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(Python COMPONENTS Interpreter Development)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
||||
@@ -38,7 +42,6 @@ target_link_libraries(mlx_ext PUBLIC mlx)
|
||||
|
||||
# Build metallib
|
||||
if(MLX_BUILD_METAL)
|
||||
|
||||
mlx_build_metallib(
|
||||
TARGET mlx_ext_metallib
|
||||
TITLE mlx_ext
|
||||
@@ -54,13 +57,15 @@ if(MLX_BUILD_METAL)
|
||||
|
||||
endif()
|
||||
|
||||
# ----------------------------- Pybind -----------------------------
|
||||
pybind11_add_module(
|
||||
mlx_sample_extensions
|
||||
# ----------------------------- Python Bindings -----------------------------
|
||||
nanobind_add_module(
|
||||
_ext
|
||||
NB_STATIC STABLE_ABI LTO NOMINSIZE
|
||||
NB_DOMAIN mlx
|
||||
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
|
||||
)
|
||||
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
target_link_libraries(_ext PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
|
18
examples/extensions/README.md
Normal file
18
examples/extensions/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
## Build the extensions
|
||||
|
||||
```
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For faster builds during development, you can also pre-install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And then run:
|
||||
|
||||
```
|
||||
python setup.py build_ext -j8 --inplace
|
||||
```
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
@@ -43,7 +43,7 @@ array axpby(
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = is_floating_point(promoted_dtype)
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
|
||||
@@ -61,7 +61,7 @@ array axpby(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
@@ -106,12 +106,12 @@ void axpby_impl(
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out_arr) {
|
||||
auto out = out_arr[0];
|
||||
std::vector<array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||
y.data_size(),
|
||||
y.strides(),
|
||||
y.flags());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
auto out = outarr[0];
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, outarr);
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#else // Accelerate not available
|
||||
@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
eval(inputs, out);
|
||||
const std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outarr) {
|
||||
std::vector<array>& outputs) {
|
||||
// Prepare inputs
|
||||
auto out = outarr[0];
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Each primitive carries the stream it should execute on
|
||||
// and each stream carries its device identifiers
|
||||
@@ -372,4 +368,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& out);
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,31 +1,31 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
R"(
|
||||
Scale and sum two vectors element-wise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
@@ -37,5 +37,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
@@ -1,3 +1,8 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"mlx>=0.9.0",
|
||||
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
4
examples/extensions/requirements.txt
Normal file
4
examples/extensions/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.9.0
|
||||
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
|
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
@@ -9,11 +9,11 @@ if __name__ == "__main__":
|
||||
name="mlx_sample_extensions",
|
||||
version="0.0.0",
|
||||
description="Sample C++ and Metal extensions for MLX primitives.",
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
|
||||
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
|
||||
cmdclass={"build_ext": extension.CMakeBuild},
|
||||
packages=["mlx_sample_extensions"],
|
||||
package_dir={"": "."},
|
||||
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
|
||||
extras_require={"dev": []},
|
||||
zip_safe=False,
|
||||
python_requires=">=3.8",
|
||||
)
|
||||
|
138
mlx/array.cpp
138
mlx/array.cpp
@@ -12,16 +12,6 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
|
||||
std::vector<size_t> strides(shape.size());
|
||||
size_t cum_prod = 1;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
strides[i] = cum_prod;
|
||||
cum_prod *= shape[i];
|
||||
}
|
||||
return {cum_prod, strides};
|
||||
}
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
@@ -36,22 +26,11 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
init(&cval);
|
||||
}
|
||||
|
||||
array::array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
shape,
|
||||
dtype,
|
||||
std::move(primitive),
|
||||
inputs)) {}
|
||||
|
||||
array::array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs)
|
||||
std::vector<array> inputs)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::move(shape),
|
||||
dtype,
|
||||
@@ -59,15 +38,16 @@ array::array(
|
||||
std::move(inputs))) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
std::vector<std::vector<int>> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<array> outputs;
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
|
||||
}
|
||||
for (int i = 0; i < outputs.size(); ++i) {
|
||||
// For each node in |outputs|, its siblings are the other nodes.
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
auto siblings = outputs;
|
||||
siblings.erase(siblings.begin() + i);
|
||||
outputs[i].set_siblings(std::move(siblings), i);
|
||||
@@ -92,10 +72,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, deleter);
|
||||
}
|
||||
|
||||
@@ -104,18 +84,22 @@ void array::detach() {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->depth = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
array_desc_->depth = 0;
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
mlx::core::eval({*this});
|
||||
// Ensure the array is ready to be read
|
||||
if (status() == Status::scheduled) {
|
||||
event().wait();
|
||||
set_status(Status::available);
|
||||
} else if (status() == Status::unscheduled) {
|
||||
mlx::core::eval({*this});
|
||||
}
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
@@ -164,51 +148,83 @@ void array::copy_shared_buffer(const array& other) {
|
||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
void array::move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
array_desc_->data = std::move(other.array_desc_->data);
|
||||
array_desc_->strides = other.strides();
|
||||
array_desc_->flags = other.flags();
|
||||
array_desc_->data_size = other.data_size();
|
||||
array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
array_desc_->data_ptr = static_cast<void*>(
|
||||
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||
: shape(shape), dtype(dtype) {
|
||||
std::tie(size, strides) = cum_prod(shape);
|
||||
void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs)
|
||||
: shape(shape),
|
||||
dtype(dtype),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(inputs) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
void array::ArrayDesc::init() {
|
||||
strides.resize(shape.size());
|
||||
size = 1;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
strides[i] = size;
|
||||
size *= shape[i];
|
||||
}
|
||||
depth++;
|
||||
for (auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs)
|
||||
std::vector<array> inputs)
|
||||
: shape(std::move(shape)),
|
||||
dtype(dtype),
|
||||
status(Status::unscheduled),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(std::move(inputs)) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destroy their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This calls recursively the destructor and can result in stack overflow, we
|
||||
// instead put them in a vector and destroy them one at a time resulting in a
|
||||
// max stack depth of 2.
|
||||
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
|
||||
|
||||
for (array& a : inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
|
||||
while (!for_deletion.empty()) {
|
||||
// top is going to be deleted at the end of the block *after* the arrays
|
||||
// with inputs have been moved into the vector
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
|
||||
for (array& a : top->inputs) {
|
||||
if (a.array_desc_.use_count() == 1) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
}
|
||||
depth++;
|
||||
}
|
||||
|
||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
|
119
mlx/array.h
119
mlx/array.h
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
@@ -8,6 +9,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/event.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -31,7 +33,7 @@ class array {
|
||||
template <typename It>
|
||||
array(
|
||||
It data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
@@ -47,13 +49,13 @@ class array {
|
||||
template <typename T>
|
||||
array(
|
||||
std::initializer_list<T> data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
allocator::Buffer data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter = allocator::free);
|
||||
|
||||
@@ -172,22 +174,16 @@ class array {
|
||||
* API may change.
|
||||
*/
|
||||
|
||||
array(
|
||||
const std::vector<int>& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
std::vector<std::vector<int>> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
@@ -261,6 +257,17 @@ class array {
|
||||
array_desc_->position = position;
|
||||
}
|
||||
|
||||
/** The i-th output of the array's primitive. */
|
||||
const array& output(int i) const {
|
||||
if (i == array_desc_->position) {
|
||||
return *this;
|
||||
} else if (i < array_desc_->position) {
|
||||
return siblings()[i];
|
||||
} else {
|
||||
return siblings()[i + 1];
|
||||
}
|
||||
};
|
||||
|
||||
/** The outputs of the array's primitive (i.e. this array and
|
||||
* its siblings) in the order the primitive expects. */
|
||||
std::vector<array> outputs() const {
|
||||
@@ -273,11 +280,6 @@ class array {
|
||||
return outputs;
|
||||
};
|
||||
|
||||
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
||||
uint16_t graph_depth() const {
|
||||
return array_desc_->depth;
|
||||
}
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
|
||||
@@ -314,9 +316,27 @@ class array {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
};
|
||||
|
||||
// Check if the array has been evaluated
|
||||
bool is_evaled() const {
|
||||
return array_desc_->data != nullptr;
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
const Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
void set_status(Status s) const {
|
||||
array_desc_->status = s;
|
||||
}
|
||||
|
||||
// Get the array's shared event
|
||||
Event& event() const {
|
||||
return array_desc_->event;
|
||||
}
|
||||
|
||||
// Attach an event to a not yet evaluated array
|
||||
void attach_event(Event e) const {
|
||||
array_desc_->event = std::move(e);
|
||||
}
|
||||
|
||||
// Mark the array as a tracer array (true) or not.
|
||||
@@ -344,6 +364,13 @@ class array {
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
void move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
void move_shared_buffer(array other);
|
||||
|
||||
void overwrite_descriptor(const array& other) {
|
||||
@@ -360,7 +387,12 @@ class array {
|
||||
std::vector<size_t> strides;
|
||||
size_t size;
|
||||
Dtype dtype;
|
||||
std::shared_ptr<Primitive> primitive{nullptr};
|
||||
std::shared_ptr<Primitive> primitive;
|
||||
|
||||
Status status;
|
||||
|
||||
// An event on the array used for synchronization
|
||||
Event event;
|
||||
|
||||
// Indicates an array is being used in a graph transform
|
||||
// and should not be detached from the graph
|
||||
@@ -368,7 +400,7 @@ class array {
|
||||
|
||||
// This is a shared pointer so that *different* arrays
|
||||
// can share the underlying data buffer.
|
||||
std::shared_ptr<Data> data{nullptr};
|
||||
std::shared_ptr<Data> data;
|
||||
|
||||
// Properly offset data pointer
|
||||
void* data_ptr{nullptr};
|
||||
@@ -388,29 +420,26 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
// The depth of the array in the graph.
|
||||
uint16_t depth{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
std::vector<array> inputs);
|
||||
|
||||
explicit ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
~ArrayDesc();
|
||||
|
||||
private:
|
||||
// Initialize size, strides, and other metadata
|
||||
void init();
|
||||
};
|
||||
|
||||
// The ArrayDesc contains the details of the materialized array including the
|
||||
// shape, strides, the data type. It also includes
|
||||
// the primitive which knows how to compute the array's data from its inputs
|
||||
// and the list of array's inputs for the primitive.
|
||||
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
||||
std::shared_ptr<ArrayDesc> array_desc_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -422,9 +451,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
template <typename It>
|
||||
array::array(
|
||||
It data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
init(data);
|
||||
}
|
||||
|
||||
@@ -441,9 +470,9 @@ array::array(
|
||||
template <typename T>
|
||||
array::array(
|
||||
std::initializer_list<T> data,
|
||||
const std::vector<int>& shape,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
throw std::invalid_argument(
|
||||
"Data size and provided shape mismatch in array construction.");
|
||||
@@ -465,10 +494,11 @@ T array::item() const {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
if (!is_evaled()) {
|
||||
if (status() == Status::unscheduled) {
|
||||
throw std::invalid_argument(
|
||||
"item() const can only be called on evaled arrays");
|
||||
}
|
||||
const_cast<array*>(this)->eval();
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
@@ -518,4 +548,15 @@ void array::init(It src) {
|
||||
}
|
||||
}
|
||||
|
||||
/* Utilities for determining whether a template parameter is array. */
|
||||
template <typename T>
|
||||
inline constexpr bool is_array_v =
|
||||
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
|
||||
|
||||
template <typename... T>
|
||||
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
|
||||
|
||||
template <typename... T>
|
||||
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@@ -196,6 +196,40 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
|
||||
return matmul_bnns_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
// Zero out every tile of a strided X-by-Y matrix whose mask entry is false.
//
// The matrix is split into tile_size x tile_size tiles (edge tiles are
// clipped to the matrix bounds). Tile (i, j) reads its mask entry at
// mask[i * X_mask_str + j * Y_mask_str]; element (x, y) of the matrix lives at
// data[x * X_data_str + y * Y_data_str].
template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int tile_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  // Number of tiles needed to cover each dimension, rounded up.
  const int tiles_x = (X + tile_size - 1) / tile_size;
  const int tiles_y = (Y + tile_size - 1) / tile_size;

  for (int tx = 0; tx < tiles_x; ++tx) {
    const int x0 = tx * tile_size;
    const int extent_x = std::min(tile_size, X - x0);

    for (int ty = 0; ty < tiles_y; ++ty) {
      // A true mask entry means the tile is kept untouched.
      if (mask[tx * X_mask_str + ty * Y_mask_str]) {
        continue;
      }

      const int y0 = ty * tile_size;
      const int extent_y = std::min(tile_size, Y - y0);
      T* tile = data + x0 * X_data_str + y0 * Y_data_str;

      for (int dx = 0; dx < extent_x; ++dx) {
        for (int dy = 0; dy < extent_y; ++dy) {
          tile[dx * X_data_str + dy * Y_data_str] = T(0.);
        }
      }
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
@@ -31,6 +31,7 @@ DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
@@ -38,6 +39,7 @@ DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
@@ -68,10 +70,13 @@ DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
@@ -297,7 +302,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -306,6 +311,19 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpm1f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -351,7 +369,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@@ -10,78 +10,65 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T, typename VT, int N>
|
||||
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
size_t s = stride;
|
||||
T* a = accum;
|
||||
while (s >= N) {
|
||||
VT val = (*(VT*)x);
|
||||
*(VT*)a += val;
|
||||
x += N;
|
||||
a += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*a++ += *x++;
|
||||
}
|
||||
}
|
||||
}
|
||||
namespace {
|
||||
|
||||
// TODO: Add proper templates for the strided reduce algorithm so we don't have
|
||||
// to write max/min/sum etc.
|
||||
template <typename T, typename VT, int N>
|
||||
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
size_t s = stride;
|
||||
T* a = accum;
|
||||
while (s >= N) {
|
||||
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
|
||||
x += N;
|
||||
a += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*a = std::max(*a, *x);
|
||||
a++;
|
||||
x++;
|
||||
}
|
||||
template <typename T, typename VT>
|
||||
struct MinReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename VT, int N>
|
||||
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
size_t s = stride;
|
||||
T* a = accum;
|
||||
while (s >= N) {
|
||||
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
|
||||
x += N;
|
||||
a += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*a = std::min(*a, *x);
|
||||
a++;
|
||||
x++;
|
||||
}
|
||||
VT operator()(VT a, VT b) {
|
||||
return simd_min(a, b);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT, int N>
|
||||
void _vectorized_sum(const T* x, T* accum, int size) {
|
||||
VT _sum = {0};
|
||||
while (size >= N) {
|
||||
_sum += (*(VT*)x);
|
||||
x += N;
|
||||
size -= N;
|
||||
template <typename T, typename VT>
|
||||
struct MaxReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return std::max(a, b);
|
||||
}
|
||||
T sum = _sum[0];
|
||||
for (int i = 1; i < N; i++) {
|
||||
sum += _sum[i];
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
*accum += sum;
|
||||
}
|
||||
};
|
||||
|
||||
// Addition functor used by the strided reduction: one overload for scalar
// elements (T) and one for vector values (VT).
template <typename T, typename VT>
struct SumReduction {
  // Scalar sum of two elements.
  T operator()(const T& lhs, const T& rhs) {
    return lhs + rhs;
  }

  // Sum of two vector values using VT's own operator+.
  VT operator()(VT lhs, VT rhs) {
    return lhs + rhs;
  }
};
|
||||
|
||||
// Accumulates `size` consecutive runs of `stride` elements from x into the
// same `stride`-long accumulator, combining values with Reduction.
//
// Each run is consumed N elements at a time through VT while at least N
// elements remain, then element by element for the remainder.
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
  void operator()(const T* x, T* accum, int size, size_t stride) {
    Reduction reduce;

    for (int run = 0; run < size; ++run) {
      // Every run restarts at the beginning of the accumulator, while x keeps
      // advancing through the input.
      T* dst = accum;
      size_t remaining = stride;

      // Wide path: reinterpret N elements as one VT value.
      for (; remaining >= N; remaining -= N, x += N, dst += N) {
        VT* wide_dst = reinterpret_cast<VT*>(dst);
        VT* wide_src = reinterpret_cast<VT*>(const_cast<T*>(x));
        *wide_dst = reduce(*wide_src, *wide_dst);
      }

      // Scalar tail for whatever is left of this run.
      for (; remaining > 0; --remaining, ++dst, ++x) {
        *dst = reduce(*dst, *x);
      }
    }
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
@@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out,
|
||||
axes_,
|
||||
0,
|
||||
[](const auto* x, auto* accum, int size, size_t stride) {
|
||||
_vectorized_strided_sum<float, simd_float16, 16>(
|
||||
(const float*)x, (float*)accum, size, stride);
|
||||
},
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
SumReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float acc;
|
||||
vDSP_sve((const float*)x, 1, &acc, size);
|
||||
@@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out,
|
||||
axes_,
|
||||
-std::numeric_limits<float>::infinity(),
|
||||
[](const auto* x, auto* accum, int size, size_t stride) {
|
||||
_vectorized_strided_max<float, simd_float16, 16>(
|
||||
(const float*)x, (float*)accum, size, stride);
|
||||
},
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MaxReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float max;
|
||||
vDSP_maxv((const float*)x, 1, &max, size);
|
||||
@@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out,
|
||||
axes_,
|
||||
std::numeric_limits<float>::infinity(),
|
||||
[](const auto* x, auto* accum, int size, size_t stride) {
|
||||
_vectorized_strided_min<float, simd_float16, 16>(
|
||||
(const float*)x, (float*)accum, size, stride);
|
||||
},
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MinReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float min;
|
||||
vDSP_minv((const float*)x, 1, &min, size);
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT, typename Ops, int N>
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
|
||||
@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
|
||||
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
|
||||
VT vals;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vals = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vals[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vmaximum = ops.max(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
T maximum = ops.reduce_max(vmaximum);
|
||||
AccT maximum = ops.reduce_max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, *current_in_ptr);
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
|
||||
ops.store(current_out_ptr, vexp);
|
||||
*(VT*)current_out_ptr = vexp;
|
||||
VT vexp;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vexp = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vexp = ops.exp(ops.sub(vexp, maximum));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = ops.add(vnormalizer, vexp);
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
T normalizer = ops.reduce_add(vnormalizer);
|
||||
AccT normalizer = ops.reduce_add(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
T _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = _exp;
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
} else {
|
||||
VT vexp;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
current_out_ptr[i] = vexp[i];
|
||||
}
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*current_out_ptr *= normalizer;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
|
||||
in, out);
|
||||
softmax<
|
||||
float,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
softmax<
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
if (precise_) {
|
||||
softmax<
|
||||
float16_t,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
eval(inputs, out);
|
||||
|
@@ -41,10 +41,10 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
@@ -53,6 +53,8 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
|
||||
|
@@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
}
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -81,13 +82,27 @@ std::string build_lib_name(
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
@@ -111,4 +126,102 @@ std::string build_lib_name(
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape) {
|
||||
bool contiguous = true;
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (const auto& x : inputs) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 0 && !shape.empty()) {
|
||||
contiguous = false;
|
||||
}
|
||||
return contiguous;
|
||||
}
|
||||
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers /* = false */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Correct size
|
||||
// - Not a scalar
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
} else {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
} else {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
}
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
|
||||
return x.ndim() == 0;
|
||||
}
|
||||
|
||||
// Check if we can use a contiguous operation given inputs and the output shape
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape);
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers = false);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -52,8 +52,25 @@ void* compile(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string kernel_file_name;
|
||||
|
||||
// Deal with long kernel names. Maximum length for files on macOS is 255
|
||||
// characters. Clip file name with a little extra room and append a 16
|
||||
// character hash.
|
||||
constexpr int max_file_name_length = 245;
|
||||
if (kernel_name.size() > max_file_name_length) {
|
||||
std::ostringstream file_name;
|
||||
file_name
|
||||
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
|
||||
auto file_id = std::hash<std::string>{}(kernel_name);
|
||||
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
|
||||
kernel_file_name = file_name.str();
|
||||
} else {
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
std::ostringstream shared_lib_name;
|
||||
shared_lib_name << "lib" << kernel_name << ".so";
|
||||
shared_lib_name << "lib" << kernel_file_name << ".so";
|
||||
auto shared_lib_path = get_temp_file(shared_lib_name.str());
|
||||
bool lib_exists = false;
|
||||
{
|
||||
@@ -64,7 +81,7 @@ void* compile(
|
||||
if (!lib_exists) {
|
||||
// Open source file and write source code to it
|
||||
std::ostringstream source_file_name;
|
||||
source_file_name << kernel_name << ".cpp";
|
||||
source_file_name << kernel_file_name << ".cpp";
|
||||
auto source_file_path = get_temp_file(source_file_name.str());
|
||||
|
||||
std::ofstream source_file(source_file_path);
|
||||
@@ -248,28 +265,7 @@ void Compiled::eval_cpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
{
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
bool all_col_contig = true;
|
||||
int non_scalar_inputs = 0;
|
||||
for (auto& x : inputs) {
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
non_scalar_inputs++;
|
||||
bool shape_eq = x.shape() == shape;
|
||||
all_contig &= (x.flags().contiguous && shape_eq);
|
||||
all_row_contig &= (x.flags().row_contiguous && shape_eq);
|
||||
all_col_contig &= (x.flags().col_contiguous && shape_eq);
|
||||
}
|
||||
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
|
||||
contiguous = false;
|
||||
} else if (non_scalar_inputs == 1 && !all_contig) {
|
||||
contiguous = false;
|
||||
}
|
||||
}
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
@@ -342,56 +338,8 @@ void Compiled::eval_cpu(
|
||||
fn_ptr = compile(kernel_name, kernel.str());
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
strides = in.strides();
|
||||
flags = in.flags();
|
||||
data_size = in.data_size();
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
}
|
||||
} else {
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
@@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) {
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim1(const array& src, array& dst) {
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim1(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[0];
|
||||
src_idx += i_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim2(const array& src, array& dst) {
|
||||
inline void copy_general_dim1(const array& src, array& dst) {
|
||||
return copy_general_dim1<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim2(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[1];
|
||||
src_idx += i_strides[1];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim3(const array& src, array& dst) {
|
||||
inline void copy_general_dim2(const array& src, array& dst) {
|
||||
return copy_general_dim2<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim3(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < src.shape()[2]; ++k) {
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
for (int k = 0; k < data_shape[2]; ++k) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[2];
|
||||
src_idx += i_strides[2];
|
||||
}
|
||||
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
|
||||
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim4(const array& src, array& dst) {
|
||||
inline void copy_general_dim3(const array& src, array& dst) {
|
||||
return copy_general_dim3<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim4(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
for (size_t k = 0; k < src.shape()[2]; ++k) {
|
||||
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
for (int k = 0; k < data_shape[2]; ++k) {
|
||||
for (int ii = 0; ii < data_shape[3]; ++ii) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[3];
|
||||
src_idx += i_strides[3];
|
||||
}
|
||||
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
|
||||
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
|
||||
}
|
||||
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
|
||||
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general(const array& src, array& dst) {
|
||||
inline void copy_general_dim4(const array& src, array& dst) {
|
||||
return copy_general_dim4<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_dim1<SrcT, DstT>(src, dst);
|
||||
copy_general_dim1<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_dim2<SrcT, DstT>(src, dst);
|
||||
copy_general_dim2<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_dim3<SrcT, DstT>(src, dst);
|
||||
copy_general_dim3<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_dim4<SrcT, DstT>(src, dst);
|
||||
copy_general_dim4<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
|
||||
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
|
||||
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
return copy_general<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
inline void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
return copy_general<SrcT, DstT, stride_t>(
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t, int D>
|
||||
inline void copy_general_general_dims(
|
||||
const array& src,
|
||||
array& dst,
|
||||
size_t offset_src,
|
||||
size_t offset_dst) {
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
if constexpr (D > 1) {
|
||||
int axis = src.ndim() - D;
|
||||
auto stride_src = src.strides()[axis];
|
||||
auto stride_dst = dst.strides()[axis];
|
||||
auto N = src.shape(axis);
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
for (int i = 0; i < N; i++) {
|
||||
copy_general_general_dims<SrcT, DstT, D - 1>(
|
||||
src, dst, offset_src, offset_dst);
|
||||
offset_src += stride_src;
|
||||
offset_dst += stride_dst;
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
i_offset += stride_src;
|
||||
o_offset += stride_dst;
|
||||
}
|
||||
} else {
|
||||
int axis = src.ndim() - 1;
|
||||
auto stride_src = src.strides()[axis];
|
||||
auto stride_dst = dst.strides()[axis];
|
||||
auto N = src.shape(axis);
|
||||
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
|
||||
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
|
||||
DstT* dst_ptr = dst.data<DstT>() + o_offset;
|
||||
for (int i = 0; i < N; i++) {
|
||||
*dst_ptr = static_cast<DstT>(*src_ptr);
|
||||
src_ptr += stride_src;
|
||||
@@ -148,37 +223,56 @@ inline void copy_general_general_dims(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_general(const array& src, array& dst) {
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 5:
|
||||
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
int size = std::accumulate(
|
||||
src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>());
|
||||
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < src.size(); i += size) {
|
||||
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
|
||||
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
|
||||
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
return copy_general_general<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
@@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst);
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst);
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT>
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype);
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype);
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype);
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype);
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype);
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype);
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype);
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype);
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
inline void copy_inplace_dispatch(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Args&&... args) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype);
|
||||
break;
|
||||
}
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype) {
|
||||
@@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
copy_inplace(src, dst, ctype);
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
return copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void copy_inplace<int64_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<int64_t>& i_strides,
|
||||
const std::vector<int64_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
return copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -26,4 +26,15 @@ enum class CopyType {
|
||||
void copy(const array& src, array& dst, CopyType ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -41,6 +41,7 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
@@ -51,11 +52,13 @@ DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
@@ -93,6 +96,7 @@ DEFAULT(Sign)
|
||||
DEFAULT(Sin)
|
||||
DEFAULT(Sinh)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
@@ -100,9 +104,11 @@ DEFAULT(Square)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
|
||||
namespace {
|
||||
|
||||
|
104
mlx/backend/common/inverse.cpp
Normal file
104
mlx/backend/common/inverse.cpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void inverse_impl(const array& a, array& inv) {
|
||||
// Lapack uses the column-major convention. We take advantage of the following
|
||||
// identity to avoid transposing (see
|
||||
// https://math.stackexchange.com/a/340234):
|
||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||
|
||||
// The inverse is computed in place, so just copy the input to the output.
|
||||
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
inverse_impl(inputs[0], output);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||
return {{linalg::inv(a, stream())}, {ax}};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
23
mlx/backend/common/lapack_helper.h
Normal file
23
mlx/backend/common/lapack_helper.h
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
|
||||
|
||||
// This is to work around a change in the function signatures of lapack >= 3.9.1
|
||||
// where functions taking char* also include a strlen argument, see a similar
|
||||
// change in OpenCV:
|
||||
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
|
||||
#define MLX_LAPACK_FUNC(f) LAPACK_##f
|
||||
|
||||
#else
|
||||
|
||||
#define MLX_LAPACK_FUNC(f) f##_
|
||||
|
||||
#endif
|
193
mlx/backend/common/masked_mm.cpp
Normal file
193
mlx/backend/common/masked_mm.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Zeroes every block of an X-by-Y strided matrix whose mask entry is false.
//
// The matrix is tiled into `block_size` x `block_size` blocks (edge blocks are
// clipped to the matrix bounds). Block (i, j) reads its flag from
// mask[i * X_mask_str + j * Y_mask_str]; element (x, y) lives at
// data[x * X_data_str + y * Y_data_str].
template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int block_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  const int blocks_x = (X + block_size - 1) / block_size;
  const int blocks_y = (Y + block_size - 1) / block_size;

  for (int bx = 0; bx < blocks_x; ++bx) {
    const int row_begin = bx * block_size;
    const int rows = std::min(block_size, X - row_begin);

    for (int by = 0; by < blocks_y; ++by) {
      // A true mask entry keeps the block untouched.
      if (mask[bx * X_mask_str + by * Y_mask_str]) {
        continue;
      }

      const int col_begin = by * block_size;
      const int cols = std::min(block_size, Y - col_begin);
      T* block = data + row_begin * X_data_str + col_begin * Y_data_str;

      for (int r = 0; r < rows; ++r) {
        T* row = block + r * X_data_str;
        for (int c = 0; c < cols; ++c) {
          row[c * Y_data_str] = T(0.);
        }
      }
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& out_mask = inputs[2];
|
||||
|
||||
auto check_transpose = [](const array& arr, bool do_copy) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(true, sty, arr_copy);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
bool has_op_mask = inputs.size() > 3;
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
auto mask_array = [](const array& mask,
|
||||
float* data,
|
||||
int block_size,
|
||||
int batch_idx,
|
||||
int X,
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str) {
|
||||
const bool* mask_ptr = mask.data<bool>() +
|
||||
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
|
||||
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask_ptr,
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
X_data_str,
|
||||
Y_data_str,
|
||||
X_mask_str,
|
||||
Y_mask_str);
|
||||
};
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
// Adjust pointer
|
||||
float* ai =
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
float* bi =
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
float* ci = out.data<float>() + M * N * i;
|
||||
|
||||
// Zero out blocks in a and b if needed
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[3];
|
||||
mask_array(
|
||||
a_mask,
|
||||
ai,
|
||||
block_size_,
|
||||
i,
|
||||
M,
|
||||
K,
|
||||
a_transposed ? 1 : lda,
|
||||
a_transposed ? lda : 1);
|
||||
|
||||
auto& b_mask = inputs[4];
|
||||
mask_array(
|
||||
b_mask,
|
||||
bi,
|
||||
block_size_,
|
||||
i,
|
||||
K,
|
||||
N,
|
||||
b_transposed ? 1 : ldb,
|
||||
b_transposed ? ldb : 1);
|
||||
}
|
||||
|
||||
// Do matmul
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1.0, // alpha
|
||||
ai,
|
||||
lda,
|
||||
bi,
|
||||
ldb,
|
||||
0.0, // beta
|
||||
ci,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
|
||||
// Zero out blocks in out
|
||||
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -241,6 +241,13 @@ struct Exp {
|
||||
}
|
||||
};
|
||||
|
||||
// Unary functor that evaluates `expm1` on its argument and returns the result
// converted to the argument's type.
struct Expm1 {
  template <typename T>
  T operator()(T x) {
    T result = expm1(x);
    return result;
  }
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
|
@@ -22,7 +22,7 @@ namespace mlx::core {
|
||||
void Abs::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (is_unsigned(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), unsignedinteger)) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcCosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcSinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::ArcTanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Ceil::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Ceil());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cos());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Cosh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -251,6 +251,62 @@ void Depends::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
numel *= inputs[0].shape(ax);
|
||||
}
|
||||
|
||||
if (inverted_) {
|
||||
numel = 1.0 / numel;
|
||||
}
|
||||
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
*out.data<bool>() = static_cast<bool>(numel);
|
||||
break;
|
||||
case uint8:
|
||||
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
|
||||
break;
|
||||
case uint16:
|
||||
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
|
||||
break;
|
||||
case uint32:
|
||||
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
|
||||
break;
|
||||
case uint64:
|
||||
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
|
||||
break;
|
||||
case int8:
|
||||
*out.data<int8_t>() = static_cast<int8_t>(numel);
|
||||
break;
|
||||
case int16:
|
||||
*out.data<int16_t>() = static_cast<int16_t>(numel);
|
||||
break;
|
||||
case int32:
|
||||
*out.data<int32_t>() = static_cast<int32_t>(numel);
|
||||
break;
|
||||
case int64:
|
||||
*out.data<int64_t>() = static_cast<int64_t>(numel);
|
||||
break;
|
||||
case float16:
|
||||
*out.data<float16_t>() = static_cast<float16_t>(numel);
|
||||
break;
|
||||
case float32:
|
||||
*out.data<float>() = static_cast<float>(numel);
|
||||
break;
|
||||
case bfloat16:
|
||||
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
|
||||
break;
|
||||
case complex64:
|
||||
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
@@ -294,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Exp());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -303,10 +359,22 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Expm1());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[expm1] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void Floor::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Floor());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -332,7 +400,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_fp(in, out, detail::Log());
|
||||
@@ -354,7 +422,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Log1p::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Log1p());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -468,27 +536,80 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous) {
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
}
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
auto [shape, _strides] = collapse_contiguous_dims(in);
|
||||
auto& strides = _strides[0];
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
int N = out.shape(i);
|
||||
if (j < shape.size() && shape[j] % N == 0) {
|
||||
shape[j] /= N;
|
||||
out_strides.push_back(shape[j] * strides[j]);
|
||||
j += (shape[j] == 1);
|
||||
} else if (N == 1) {
|
||||
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
|
||||
out_strides.push_back(out_strides.back());
|
||||
} else {
|
||||
copy_necessary = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
// For row contiguous reshapes:
|
||||
// - Shallow copy the buffer
|
||||
// - If reshaping into a vector (all singleton dimensions except one) it
|
||||
// becomes col contiguous again.
|
||||
auto flags = in.flags();
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
|
||||
} else {
|
||||
}
|
||||
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Round());
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -499,7 +620,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sigmoid());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -521,7 +642,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sin());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -533,7 +654,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Sinh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -542,36 +663,33 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
auto flags = in.flags();
|
||||
size_t data_offset = 0;
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
strides[i] *= strides_[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
size_t data_size = 1;
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.row_contiguous = true;
|
||||
flags.col_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
|
||||
flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
|
||||
flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
|
||||
f_stride *= out.shape(i);
|
||||
b_stride *= out.shape(ri);
|
||||
if (strides[i] > 0) {
|
||||
data_size *= out.shape(i);
|
||||
}
|
||||
}
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
@@ -585,7 +703,87 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_inplace<int64_t>(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General);
|
||||
} else {
|
||||
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, out);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
}
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if materialization is needed
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] = prepare_slice(out);
|
||||
|
||||
// Do copy
|
||||
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
|
||||
copy_inplace<int64_t>(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ upd_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ out_strides,
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ data_offset,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral);
|
||||
}
|
||||
|
||||
void Split::eval(
|
||||
@@ -664,7 +862,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tan());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
@@ -676,7 +874,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
|
||||
void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, detail::Tanh());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
|
@@ -6,8 +6,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
enum ReductionOpType {
|
||||
// Self-explanatory. Read everything and produce 1 output.
|
||||
ContiguousAllReduce,
|
||||
@@ -38,6 +36,21 @@ enum ReductionOpType {
|
||||
GeneralReduce
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
inline void nd_loop(
|
||||
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
|
||||
}
|
||||
};
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
@@ -1,13 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void RoPE::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -222,7 +222,7 @@ void scan_dispatch(
|
||||
}
|
||||
case Scan::Min: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
@@ -232,7 +232,7 @@ void scan_dispatch(
|
||||
}
|
||||
case Scan::Max: {
|
||||
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
|
||||
auto init = (is_floating_point(input.dtype()))
|
||||
auto init = (issubdtype(input.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -10,7 +10,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename AccT>
|
||||
void softmax(const array& in, array& out) {
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
|
||||
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
T maximum = *current_in_ptr;
|
||||
AccT maximum = *current_in_ptr;
|
||||
for (int j = 0; j < N; j++, current_in_ptr++) {
|
||||
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
|
||||
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
|
||||
: maximum;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
T normalizer = 0;
|
||||
AccT normalizer = 0;
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
|
||||
T expv = std::exp(*current_in_ptr - maximum);
|
||||
AccT expv = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += expv;
|
||||
*current_out_ptr = expv;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = expv;
|
||||
}
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_in_ptr = in_ptr;
|
||||
current_out_ptr = out_ptr;
|
||||
for (int j = 0; j < N; j++, current_out_ptr++) {
|
||||
*current_out_ptr *= normalizer;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
auto v = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(v * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,11 +77,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
};
|
||||
array in = check_input(std::move(inputs[0]));
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
@@ -87,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float>(in, out);
|
||||
softmax<float, float>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
softmax<float16_t>(in, out);
|
||||
if (precise_) {
|
||||
softmax<float16_t, float>(in, out);
|
||||
} else {
|
||||
softmax<float16_t, float16_t>(in, out);
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
softmax<bfloat16_t>(in, out);
|
||||
if (precise_) {
|
||||
softmax<bfloat16_t, float>(in, out);
|
||||
} else {
|
||||
softmax<bfloat16_t, bfloat16_t>(in, out);
|
||||
}
|
||||
break;
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
|
156
mlx/backend/common/svd.cpp
Normal file
156
mlx/backend/common/svd.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack_helper.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
// Lapack uses the column-major convention. To avoid having to transpose
|
||||
// the input and then transpose the outputs, we swap the indices/sizes of the
|
||||
// matrices and take advantage of the following identity (see
|
||||
// https://math.stackexchange.com/a/30077)
|
||||
// A = UΣVᵀ
|
||||
// Aᵀ = VΣUᵀ
|
||||
// As a result some of the indices/sizes are swapped as noted above.
|
||||
|
||||
// Rows and cols of the original matrix in row-major order.
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
|
||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
||||
const int lda = N;
|
||||
// U of shape M x M. (N x N in lapack).
|
||||
const int ldu = N;
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), float32, nullptr, {});
|
||||
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
// Allocate outputs.
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
|
||||
static constexpr auto job_u = "V";
|
||||
static constexpr auto job_vt = "V";
|
||||
static constexpr auto range = "A";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
float workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not used
|
||||
// here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const float ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
MLX_LAPACK_FUNC(sgesvdx)
|
||||
(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
MLX_LAPACK_FUNC(sgesvdx)
|
||||
(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in.data<float>() + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s.data<float>() + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt.data<float>() + N * N * i,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u.data<float>() + M * M * i,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[SVD::eval] only supports float32.");
|
||||
}
|
||||
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,11 +8,12 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline size_t elem_to_loc(
|
||||
template <typename stride_t>
|
||||
inline stride_t elem_to_loc(
|
||||
int elem,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
size_t loc = 0;
|
||||
const std::vector<stride_t>& strides) {
|
||||
stride_t loc = 0;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(elem, shape[i]);
|
||||
loc += q_and_r.rem * strides[i];
|
||||
@@ -28,4 +29,93 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
||||
return elem_to_loc(elem, a.shape(), a.strides());
|
||||
}
|
||||
|
||||
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
//
// Two neighbouring axes (i - 1, i) are merged only when, for every stride
// vector, strides[i] * shape[i] == strides[i - 1]. Each merged group reports
// the stride of its last (innermost) axis. Axes are additionally left
// unmerged when the merged extent would no longer fit in an `int`, so the
// returned shape never holds an overflowed value.
//
// Returns the collapsed shape together with one collapsed stride vector per
// entry of `strides`. An empty `shape` yields an empty shape and empty stride
// vectors.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<stride_t>>& strides) {
  std::vector<int> out_shape;
  std::vector<std::vector<stride_t>> out_strides(strides.size());

  if (shape.empty()) {
    return std::make_tuple(out_shape, out_strides);
  }

  // Emit one collapsed dim whose innermost original axis is `last_axis`.
  auto close_group = [&](int extent, size_t last_axis) {
    out_shape.push_back(extent);
    for (size_t j = 0; j < strides.size(); ++j) {
      out_strides[j].push_back(strides[j][last_axis]);
    }
  };

  int group_extent = shape[0];
  for (size_t axis = 1; axis < shape.size(); ++axis) {
    // The axis can join the current group only if it is contiguous with the
    // previous axis in every stride vector.
    bool mergeable = true;
    for (const std::vector<stride_t>& st : strides) {
      if (st[axis] * shape[axis] != st[axis - 1]) {
        mergeable = false;
        break;
      }
    }

    // Compute the merged extent in a wider type; it fits in `int` exactly when
    // narrowing and widening again gives the same value back.
    const long long merged =
        static_cast<long long>(group_extent) * shape[axis];
    const bool fits_in_int =
        merged == static_cast<long long>(static_cast<int>(merged));

    if (mergeable && fits_in_int) {
      group_extent = static_cast<int>(merged);
    } else {
      close_group(group_extent, axis - 1);
      group_extent = shape[axis];
    }
  }
  close_group(group_extent, shape.size() - 1);

  return std::make_tuple(out_shape, out_strides);
}
|
||||
|
||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
}
|
||||
|
||||
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||
inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||
return collapse_contiguous_dims(
|
||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||
}
|
||||
|
||||
// Inspects a shape/strides pair and reports
//   - data_size: the product of the extents of all axes whose stride is > 0
//     (axes with a zero stride, e.g. broadcast axes, do not contribute),
//   - is_row_contiguous: strides match a dense row-major layout (last axis
//     has stride 1, each earlier axis the product of the later extents),
//   - is_col_contiguous: strides match a dense column-major layout (first
//     axis has stride 1, each later axis the product of the earlier extents).
// Axes of extent 1 are ignored by both contiguity checks.
//
// Returns std::tuple(data_size, is_row_contiguous, is_col_contiguous).
template <typename stride_t>
inline auto check_contiguity(
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  size_t data_size = 1;
  // Expected stride when walking from the first axis (column-major order).
  size_t f_stride = 1;
  // Expected stride when walking from the last axis (row-major order).
  size_t b_stride = 1;
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;

  const int ndim = static_cast<int>(shape.size());
  for (int i = 0, ri = ndim - 1; ri >= 0; i++, ri--) {
    // Forward walk: the first axis varies fastest, i.e. column-major.
    is_col_contiguous &=
        static_cast<size_t>(strides[i]) == f_stride || shape[i] == 1;
    // Backward walk: the last axis varies fastest, i.e. row-major.
    is_row_contiguous &=
        static_cast<size_t>(strides[ri]) == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {
      data_size *= shape[i];
    }
  }

  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -26,6 +26,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
@@ -33,6 +34,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
#include <mach/vm_page_size.h>
|
||||
#include <unistd.h>
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/compiled_preamble.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@@ -228,14 +229,7 @@ void Compiled::eval_gpu(
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
!is_scalar(x)) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool contiguous = compiled_check_contiguity(inputs, output_shape);
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
@@ -295,7 +289,7 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
@@ -306,7 +300,7 @@ void Compiled::eval_gpu(
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
compute_encoder.set_input_array(x, cnt++);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
@@ -316,30 +310,12 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
{
|
||||
int o = 0;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
auto& in = inputs[i];
|
||||
// Conditions for donation
|
||||
// - Row contiguous
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
compute_encoder.set_output_array(x, cnt++);
|
||||
}
|
||||
|
||||
// Put the output shape and strides in
|
||||
|
@@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<N>& conv_params) {
|
||||
// Get gemm shapes
|
||||
int implicit_M = out.size() / conv_params.O;
|
||||
int implicit_K = wt.size() / conv_params.O;
|
||||
int implicit_N = conv_params.O;
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape = {
|
||||
static_cast<int>(out.size() / conv_params.O),
|
||||
static_cast<int>(wt.size() / conv_params.O)};
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
@@ -39,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, in_unfolded, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(in_unfolded, 1);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
@@ -59,20 +61,29 @@ void explicit_gemm_conv_ND_gpu(
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
|
||||
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
||||
auto wt_flags = wt.flags();
|
||||
wt_flags.row_contiguous = false;
|
||||
wt_flags.col_contiguous = true;
|
||||
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies;
|
||||
std::vector<array> copies = {in_unfolded, wt_reshaped};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt,
|
||||
/*b = */ wt_reshaped,
|
||||
/*c = */ out,
|
||||
/*M = */ unfolded_shape[0],
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ unfolded_shape[1],
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*batch_size_out = */ 1,
|
||||
/*a_cols = */ unfolded_shape[1],
|
||||
/*b_cols = */ unfolded_shape[1],
|
||||
/*a_cols = */ implicit_K,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/*copies = */ copies);
|
||||
@@ -129,7 +140,7 @@ void slow_conv_2D_gpu(
|
||||
<< "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -142,9 +153,9 @@ void slow_conv_2D_gpu(
|
||||
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
@@ -230,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
<< "_filter_" << (small_filter ? 's' : 'l');
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -243,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
@@ -383,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -397,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||
|
||||
// Encode arrays
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, wt, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode params
|
||||
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||
@@ -500,12 +511,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, wt, 0);
|
||||
set_array_buffer(compute_encoder, filt_wg, 1);
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
compute_encoder.set_output_array(filt_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||
@@ -528,12 +539,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in_padded, 0);
|
||||
set_array_buffer(compute_encoder, inp_wg, 1);
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
compute_encoder.set_output_array(inp_wg, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
@@ -576,12 +587,12 @@ void winograd_conv_2D_gpu(
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, out_wg, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
compute_encoder->setBytes(
|
||||
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
@@ -12,8 +12,15 @@ namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
// copy, just use the buffer.
|
||||
if (in.dtype() == out.dtype()) {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
@@ -37,15 +44,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
|
||||
copy_gpu(in, out, ctype, out.primitive().stream());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& strides_in_pre,
|
||||
const std::vector<stride_t>& strides_out_pre,
|
||||
int64_t inp_offset,
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(in, out);
|
||||
auto& strides_in = strides[0];
|
||||
auto& strides_out = strides[1];
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector{strides_in_pre, strides_out_pre});
|
||||
auto& strides_in_ = strides[0];
|
||||
auto& strides_out_ = strides[1];
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
std::ostringstream kname;
|
||||
@@ -69,42 +83,47 @@ void copy_gpu_inplace(
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
inp_offset *= size_of(in.dtype());
|
||||
out_offset *= size_of(out.dtype());
|
||||
|
||||
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
|
||||
compute_encoder.set_output_array(out, 1, out_offset);
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
size_t ndim = shape.size();
|
||||
int ndim = shape.size();
|
||||
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
|
||||
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
|
||||
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, shape, ndim, 2);
|
||||
}
|
||||
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(
|
||||
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 5);
|
||||
}
|
||||
|
||||
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
int rest = in.size() / (dim0 * dim1);
|
||||
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
int rest = data_size / (dim0 * dim1);
|
||||
|
||||
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
|
||||
}
|
||||
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
@@ -120,4 +139,25 @@ void copy_gpu_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int64_t>& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,12 +7,34 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Generic copy inplace
|
||||
template <typename stride_t>
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype);
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& src,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<int64_t>& istride,
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@@ -20,9 +22,9 @@ namespace mlx::core::metal {
|
||||
namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
static constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
@@ -127,7 +129,7 @@ Device::~Device() {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& e : encoder_map_) {
|
||||
e.second->release();
|
||||
(*e.second)->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
@@ -145,6 +147,7 @@ void Device::new_queue(int index) {
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
@@ -197,22 +200,25 @@ void Device::commit_command_buffer(int index) {
|
||||
void Device::end_encoding(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit != encoder_map_.end()) {
|
||||
eit->second->endEncoding();
|
||||
eit->second->release();
|
||||
(*eit->second)->endEncoding();
|
||||
(*eit->second)->release();
|
||||
encoder_map_.erase(eit);
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
auto cb = get_command_buffer(index);
|
||||
auto compute_encoder = cb->computeCommandEncoder();
|
||||
auto compute_encoder =
|
||||
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
compute_encoder->retain();
|
||||
eit = encoder_map_.insert({index, compute_encoder}).first;
|
||||
eit = encoder_map_
|
||||
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
|
||||
.first;
|
||||
}
|
||||
return eit->second;
|
||||
return *(eit->second);
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
@@ -538,11 +544,12 @@ Device& device(mlx::core::Device) {
|
||||
return metal_device;
|
||||
}
|
||||
|
||||
std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
|
||||
auto dtor = [](void* ptr) {
|
||||
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
||||
};
|
||||
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
return std::unique_ptr<void, std::function<void(void*)>>(
|
||||
NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
@@ -34,6 +36,70 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(MTL::ComputeCommandEncoder* enc)
|
||||
: enc(enc), concurrent(false){};
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
|
||||
enc.concurrent = true;
|
||||
}
|
||||
~ConcurrentContext() {
|
||||
enc.concurrent = false;
|
||||
enc.outputs.insert(
|
||||
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
|
||||
enc.concurrent_outputs.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
|
||||
MTL::ComputeCommandEncoder* operator->() {
|
||||
return enc;
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int offset = 0) {
|
||||
auto r_buf =
|
||||
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void set_output_array(array& a, int idx, int offset = 0) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
ConcurrentContext start_concurrent() {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
|
||||
private:
|
||||
MTL::ComputeCommandEncoder* enc;
|
||||
bool concurrent;
|
||||
std::unordered_set<MTL::Resource*> outputs;
|
||||
std::unordered_set<MTL::Resource*> concurrent_outputs;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device();
|
||||
@@ -51,7 +117,7 @@ class Device {
|
||||
int get_command_buffer_ops(int index);
|
||||
void increment_command_buffer_ops(int index);
|
||||
void commit_command_buffer(int index);
|
||||
MTL::ComputeCommandEncoder* get_command_encoder(int index);
|
||||
CommandEncoder& get_command_encoder(int index);
|
||||
void end_encoding(int index);
|
||||
|
||||
void register_library(
|
||||
@@ -132,7 +198,7 @@ class Device {
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
|
||||
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||
std::mutex mtx_;
|
||||
|
30
mlx/backend/metal/event.cpp
Normal file
30
mlx/backend/metal/event.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Creates an event for `stream`, backed by a new MTL::SharedEvent obtained
// from the Metal device that owns the stream. The shared event is released
// when the last copy of event_ goes away.
Event::Event(const Stream& stream) : stream_(stream) {
  // Autorelease pool covering the allocation below.
  auto pool = metal::new_scoped_memory_pool();
  auto* shared_event =
      metal::device(stream.device).mtl_device()->newSharedEvent();
  event_ = std::shared_ptr<void>(shared_event, [](void* raw) {
    // The deleter opens its own pool before releasing the event.
    auto release_pool = metal::new_scoped_memory_pool();
    static_cast<MTL::SharedEvent*>(raw)->release();
  });
}
|
||||
|
||||
void Event::wait() {
|
||||
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
|
||||
->waitUntilSignaledValue(value(), -1)) {
|
||||
throw std::runtime_error("[Event::wait] Timed out");
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal() {
|
||||
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,12 +1,106 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& in = inputs[0];
|
||||
throw std::runtime_error("[FFT] NYI for Metal backend.");
|
||||
|
||||
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
|
||||
in.dtype() != complex64 || out.dtype() != complex64) {
|
||||
// Could also fallback to CPU implementation here.
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
|
||||
}
|
||||
|
||||
size_t n = in.shape(axes_[0]);
|
||||
|
||||
if (!is_power_of_2(n) || n > 2048 || n < 4) {
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
|
||||
}
|
||||
|
||||
// Make sure that the array is contiguous and has stride 1 in the FFT dim
|
||||
std::vector<array> copies;
|
||||
auto check_input = [this, &copies, &s](const array& x) {
|
||||
// TODO: Pass the strides to the kernel so
|
||||
// we can avoid the copy when x is not contiguous.
|
||||
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
|
||||
x.flags().col_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axes_[0]);
|
||||
for (int axis = 0; axis < x.ndim(); axis++) {
|
||||
if (axis == axes_[0]) {
|
||||
strides.push_back(1);
|
||||
} else {
|
||||
strides.push_back(cur_stride);
|
||||
cur_stride *= x.shape(axis);
|
||||
}
|
||||
}
|
||||
|
||||
auto flags = x.flags();
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
|
||||
f_stride *= x.shape(i);
|
||||
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
|
||||
b_stride *= x.shape(ri);
|
||||
}
|
||||
// This is probably over-conservative
|
||||
flags.contiguous = false;
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(inputs[0]);
|
||||
|
||||
// TODO: allow donation here
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
in_contiguous.data_size(),
|
||||
in_contiguous.strides(),
|
||||
in_contiguous.flags());
|
||||
|
||||
// We use n / 4 threads by default since radix-4
|
||||
// is the largest single threaded radix butterfly
|
||||
// we currently implement.
|
||||
size_t m = n / 4;
|
||||
size_t batch = in.size() / in.shape(axes_[0]);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "fft_" << n;
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto group_dims = MTL::Size(1, m, 1);
|
||||
auto grid_dims = MTL::Size(batch, m, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -16,7 +16,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kname << "_" << idx_ndim;
|
||||
}
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, src, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(src, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
@@ -103,7 +103,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
compute_encoder.set_input_array(inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
kname << "_" << nidx;
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
auto& upd = inputs.back();
|
||||
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, upd, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder.set_input_array(upd, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Set update info
|
||||
uint upd_ndim = upd.ndim();
|
||||
@@ -201,19 +201,16 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
|
||||
if (index_nd1_specialization) {
|
||||
bool upd_col_contiguous = upd.flags().col_contiguous;
|
||||
compute_encoder->setBytes(
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
compute_encoder.set_input_array(inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
@@ -283,7 +280,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
compute_encoder.set_input_array(inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
|
@@ -7,6 +7,7 @@ set(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||
@@ -20,9 +21,12 @@ set(
|
||||
"binary_two"
|
||||
"conv"
|
||||
"copy"
|
||||
"fft"
|
||||
"gemv"
|
||||
"quantized"
|
||||
"random"
|
||||
"rms_norm"
|
||||
"layer_norm"
|
||||
"rope"
|
||||
"scan"
|
||||
"scaled_dot_product_attention"
|
||||
@@ -35,11 +39,17 @@ set(
|
||||
)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
||||
-fno-fast-math
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
${METAL_FLAGS}
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS}
|
||||
OUTPUT ${TARGET}.air
|
||||
|
@@ -1,29 +1,29 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
// Scalar broadcast copy: every thread writes src[0], converted to U, into
// dst at its own grid position.
template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}
|
||||
|
||||
// Element-wise copy: each thread converts src[index] to U and stores it at
// the same index in dst.
template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t& src_stride,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
@@ -31,61 +31,61 @@ template <typename T, typename U>
|
||||
|
||||
// 2-D strided-source copy. The source location comes from elem_to_loc_2 and
// src_strides; the destination is written densely in grid order with x as
// the fastest-varying coordinate.
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  // Widen before multiplying so the linear index is computed in 64 bits.
  int64_t dst_idx = index.x + static_cast<int64_t>(grid_dim.x) * index.y;
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
|
||||
|
||||
// 3-D strided-source copy. The source location comes from elem_to_loc_3 and
// src_strides; the destination is written densely in grid order (x fastest,
// then y, then z).
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  // Widen before multiplying so the linear index is computed in 64 bits.
  int64_t dst_idx = index.x +
      static_cast<int64_t>(grid_dim.x) *
          (index.y + static_cast<int64_t>(grid_dim.y) * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
|
||||
|
||||
// Strided-source copy specialised on the number of dimensions DIM. The
// source location comes from elem_to_loc_nd<DIM> with src_shape and
// src_strides; the destination is written densely in grid order (x fastest,
// then y, then z).
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  // Widen before multiplying so the linear index is computed in 64 bits.
  int64_t dst_idx = index.x +
      static_cast<int64_t>(grid_dim.x) *
          (index.y + static_cast<int64_t>(grid_dim.y) * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
|
||||
|
||||
// Strided-source copy for a runtime number of dimensions (ndim). The source
// location comes from elem_to_loc with src_shape and src_strides; the
// destination is written densely in grid order (x fastest, then y, then z).
template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  // Widen before multiplying so the linear index is computed in 64 bits.
  int64_t dst_idx = index.x +
      static_cast<int64_t>(grid_dim.x) *
          (index.y + static_cast<int64_t>(grid_dim.y) * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t& src_stride,
|
||||
constant const size_t& dst_stride,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
@@ -94,10 +94,10 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[2],
|
||||
constant const size_t dst_strides[2],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
@@ -106,10 +106,10 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[3],
|
||||
constant const size_t dst_strides[3],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
@@ -118,11 +118,11 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_gg_nd(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int src_shape[DIM],
|
||||
constant const size_t src_strides[DIM],
|
||||
constant const size_t dst_strides[DIM],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
|
||||
@@ -131,12 +131,12 @@ template <typename T, typename U, int DIM>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int* src_shape,
|
||||
constant const size_t* src_strides,
|
||||
constant const size_t* dst_strides,
|
||||
constant const int& ndim,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
|
||||
@@ -146,70 +146,70 @@ template <typename T, typename U>
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_##ctype<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int src_shape[dims], \
|
||||
constant const size_t src_strides[dims], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_" #dims)]] \
|
||||
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int src_shape[dims], \
|
||||
constant const size_t src_strides[dims], \
|
||||
constant const size_t dst_strides[dims], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t& src_stride, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[2], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[3], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_1")]] \
|
||||
[[kernel]] void copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t& src_stride, \
|
||||
constant const size_t& dst_stride, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_2")]] \
|
||||
[[kernel]] void copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[2], \
|
||||
constant const size_t dst_strides[2], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_3")]] \
|
||||
[[kernel]] void copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[3], \
|
||||
constant const size_t dst_strides[3], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
@@ -218,21 +218,21 @@ template <typename T, typename U>
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int* src_shape, \
|
||||
constant const size_t* src_strides, \
|
||||
constant const int& ndim, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name)]] \
|
||||
[[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int* src_shape, \
|
||||
constant const size_t* src_strides, \
|
||||
constant const size_t* dst_strides, \
|
||||
constant const int& ndim, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
|
@@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
|
||||
static MTL_CONST constexpr int REDUCE_N_READS = 16;
|
||||
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
|
||||
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
|
||||
static MTL_CONST constexpr int RMS_N_READS = 4;
|
||||
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
|
||||
|
89
mlx/backend/metal/kernels/expm1f.h
Normal file
89
mlx/backend/metal/kernels/expm1f.h
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
// Original license copied below:
|
||||
// Copyright (c) 2015-2023 Norbert Juffa
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions
|
||||
// are met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
/* Compute exponential base e minus 1. Maximum ulp error = 0.997458
|
||||
|
||||
i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
|
||||
Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
|
||||
With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
|
||||
when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.
|
||||
|
||||
NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)
|
||||
*/
|
||||
// Core of expm1f: evaluates exp(a) - 1 following the derivation in the comment
// above, without any handling of severe overflow/underflow (the caller does
// that). The scale factor b only takes effect when i < 0 or i > 1 and should
// be a power of two.
float expm1f_scaled_unchecked(float a, float b) {
  float f, j, r, s, t, u, v, x, y;
  int i;

  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
  // Adding and subtracting 0x1.8p23 rounds a / log(2) to the nearest integer.
  j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
  j = j - 12582912.0f; // 0x1.8p23
  i = (int)j;
  // Two-constant reduction f = a - j * log(2). The high part has trailing
  // zero bits so j * log_2_hi is exact; the low part supplies the remaining
  // bits of log(2). Without it f is off by about |j| * 1.43e-6, which breaks
  // the sub-ulp error bound quoted above.
  f = fma(j, -6.93145752e-1f, a); // -0x1.62e400p-1  // log_2_hi
  f = fma(j, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 // log_2_lo

  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
  s = f * f;
  if (a == 0.0f)
    s = a; // ensure -0 is passed through
  // err = 0.997458 ulp1 = 11081805
  r = 1.97350979e-4f; // 0x1.9de000p-13
  r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
  r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
  r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
  r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
  r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
  // For i == 1 the result is 2 * (r + 0.5); fold the 0.5 in here.
  u = (j == 1) ? (f + 0.5f) : f;
  v = fma(r, s, u);
  // General case: 2 * (v * t + (t - 0.5 * b)) with t = 0.5 * b * 2**i.
  s = 0.5f * b;
  t = ldexp(s, i);
  y = t - s;
  x = (t - y) - s; // double-float canonicalization of difference
  r = fma(v, t, x) + y;
  r = r + r;
  // Special cases that bypass the scaling for best accuracy.
  if (j == 0)
    r = v;
  if (j == 1)
    r = v + v;
  return r;
}
|
||||
|
||||
/* Compute exponential base e minus 1. max ulp err = 0.99746 */
|
||||
float expm1f(float a) {
  // Unscaled (b = 1) evaluation with no range checks.
  const float unchecked = expm1f_scaled_unchecked(a, 1.0f);

  /* handle severe overflow and underflow: only when |a - 1| exceeds 88 is
     the unchecked value replaced by unchecked * unchecked - 1 */
  const bool far_out_of_range = abs(a - 1.0f) > 88.0f;
  return far_out_of_range ? fma(unchecked, unchecked, -1.0f) : unchecked;
}
|
195
mlx/backend/metal/kernels/fft.metal
Normal file
195
mlx/backend/metal/kernels/fft.metal
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_math>
|
||||
#include <metal_common>
|
||||
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Complex product a * b, with .x holding the real part and .y the imaginary
// part of each operand and of the result.
float2 complex_mul(float2 a, float2 b) {
  const float real_part = a.x * b.x - a.y * b.y;
  const float imag_part = a.x * b.y + a.y * b.x;
  return float2(real_part, imag_part);
}
|
||||
|
||||
// Returns (cos(angle), sin(angle)) for angle = -pi * k / (2 * p), evaluated
// with Metal's fast trigonometric functions.
float2 get_twiddle(int k, int p) {
  const float angle = -1.0f * k * M_PI_F / (2*p);
  return float2(metal::fast::cos(angle), metal::fast::sin(angle));
}
|
||||
|
||||
// single threaded radix2 implementation
|
||||
// One radix-2 butterfly: reads the inputs at i and i + m from read_buf,
// applies the twiddle to the second one, and writes the sum and difference to
// write_buf at j and j + p.
//   i: butterfly index handled by this call
//   p: length of the sub-DFTs being merged (a power of two, since k is
//      extracted with a bit mask)
//   m: distance between the two inputs in read_buf
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
  float2 x_0 = read_buf[i];
  float2 x_1 = read_buf[i + m];

  // The index within this sub-DFT
  int k = i & (p - 1);

  // A radix-2 step merges sub-DFTs of length p into length 2p, so its twiddle
  // angle is -2*pi*k / (2p) = -pi*k / p. get_twiddle(k, p) evaluates
  // -pi*k / (2p) (the radix-4 angle), so the index is doubled here. For
  // p == 1 (k == 0) the twiddle is unchanged.
  float2 twiddle = get_twiddle(2 * k, p);

  float2 z = complex_mul(x_1, twiddle);

  float2 y_0 = x_0 + z;
  float2 y_1 = x_0 - z;

  // Output position: equivalent to ((i - k) << 1) + k.
  int j = (i << 1) - k;

  write_buf[j] = y_0;
  write_buf[j + p] = y_1;
}
|
||||
|
||||
// single threaded radix4 implementation
|
||||
// One radix-4 butterfly: reads four inputs spaced m apart starting at i,
// twiddles the last three, applies a 4-point DFT and writes the results to
// write_buf spaced p apart starting at j.
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
  // Position of this butterfly inside its length-p sub-DFT.
  const int k = i & (p - 1);

  // Powers of the twiddle factor, built by repeated multiplication:
  // e^a * e^b = e^(a + b)
  const float2 w1 = get_twiddle(k, p);
  const float2 w2 = complex_mul(w1, w1);
  const float2 w3 = complex_mul(w1, w2);

  // Gather the inputs; every one except the first is twiddled.
  const float2 a0 = read_buf[i];
  const float2 a1 = complex_mul(read_buf[i + m], w1);
  const float2 a2 = complex_mul(read_buf[i + 2*m], w2);
  const float2 a3 = complex_mul(read_buf[i + 3*m], w3);

  // Hard coded twiddle factors for DFT4: the only non-trivial one is -i.
  const float2 minus_i = float2(0.0f, -1.0f);

  const float2 even_sum = a0 + a2;
  const float2 even_diff = a0 - a2;
  const float2 odd_sum = a1 + a3;
  const float2 odd_diff_rot = complex_mul(a1 - a3, minus_i);

  // Base output index; the four results land p apart from each other.
  const int j = ((i - k) << 2) + k;

  write_buf[j] = even_sum + odd_sum;
  write_buf[j + p] = even_diff + odd_diff_rot;
  write_buf[j + 2*p] = even_sum - odd_sum;
  write_buf[j + 3*p] = even_diff - odd_diff_rot;
}
|
||||
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
//
|
||||
// At each step we use n / 4 threads, each performing
|
||||
// a single-threaded radix-4 or radix-2 DFT.
|
||||
//
|
||||
// We provide the number of radix-2 and radix-4
|
||||
// steps at compile time for a ~20% performance boost.
|
||||
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
    const device float2* in [[buffer(0)]],
    device float2* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]],
    uint3 threads_per_grid [[threads_per_grid]]) {
  // Offset of this thread's DFT inside the batched in/out buffers.
  const int batch_idx = thread_position_in_grid.x * n;
  // This thread's index within its DFT.
  const int i = thread_position_in_grid.y;
  // Number of threads cooperating on one DFT.
  const int m = threads_per_grid.y;

  // Two threadgroup buffers for Stockham: every radix step reads from one and
  // writes to the other, then the roles are exchanged.
  threadgroup float2 shared_in[n];
  threadgroup float2 shared_out[n];
  threadgroup float2* read_buf = shared_in;
  threadgroup float2* write_buf = shared_out;

  // Stage the input: each thread copies four elements spaced m apart.
  for (int q = 0; q < 4; q++) {
    shared_in[i + q*m] = in[batch_idx + i + q*m];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Current sub-DFT length.
  int p = 1;

  // Radix-2 steps: two butterflies per thread, inputs 2*m apart.
  for (size_t step = 0; step < radix_2_steps; step++) {
    radix2(i, p, m*2, read_buf, write_buf);
    radix2(i + m, p, m*2, read_buf, write_buf);
    p *= 2;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Exchange the Stockham buffers.
    threadgroup float2* previous_write = write_buf;
    write_buf = read_buf;
    read_buf = previous_write;
  }

  // Radix-4 steps: one butterfly per thread.
  for (size_t step = 0; step < radix_4_steps; step++) {
    radix4(i, p, m, read_buf, write_buf);
    p *= 4;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Exchange the Stockham buffers.
    threadgroup float2* previous_write = write_buf;
    write_buf = read_buf;
    read_buf = previous_write;
  }

  // After the final exchange the result sits in read_buf; copy it out with
  // the same four-elements-per-thread pattern used for the input.
  for (int q = 0; q < 4; q++) {
    out[batch_idx + i + q*m] = read_buf[i + q*m];
  }
}
|
||||
|
||||
// Emits an explicit instantiation of fft<size, r2_steps, r4_steps> that is
// exported to the host under the name "fft_<tag>".
#define instantiate_fft(tag, size, r2_steps, r4_steps)                \
  template [[host_name("fft_" #tag)]]                                 \
  [[kernel]] void fft<size, r2_steps, r4_steps>(                      \
      const device float2* in [[buffer(0)]],                          \
      device float2* out [[buffer(1)]],                               \
      uint3 thread_position_in_grid [[thread_position_in_grid]],      \
      uint3 threads_per_grid [[threads_per_grid]]);

// One kernel per power of two; arguments are (tag, n, radix-2 steps,
// radix-4 steps), with n = 2^(radix-2 steps) * 4^(radix-4 steps).
instantiate_fft(4, 4, 0, 1)
instantiate_fft(8, 8, 1, 1)
instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5)
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
@@ -22,7 +22,8 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN > /* Thread cols (in elements) */
|
||||
const int TN , /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVKernel {
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
@@ -48,11 +49,16 @@ struct GEMVKernel {
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat,
|
||||
const device T* in_vec,
|
||||
device T* out_vec,
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -81,7 +87,7 @@ struct GEMVKernel {
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * in_vec_size;
|
||||
mat += out_row * marix_ld;
|
||||
|
||||
// Loop over in_vec in blocks of BN * TN
|
||||
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
@@ -124,14 +130,14 @@ struct GEMVKernel {
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
inter[tn] = mat[tm * marix_ld + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * in_vec_size + col_idx];
|
||||
inter[tn] = mat[tm * marix_ld + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,7 +160,13 @@ struct GEMVKernel {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
if(kDoAxpby) {
|
||||
out_vec[out_row + tm] =
|
||||
static_cast<T>(alpha) * result[tm] +
|
||||
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -172,7 +184,8 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN > /* Thread cols (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVTKernel {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
@@ -197,11 +210,16 @@ struct GEMVTKernel {
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat,
|
||||
const device T* in_vec,
|
||||
device T* out_vec,
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@@ -245,7 +263,7 @@ struct GEMVTKernel {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
@@ -257,7 +275,7 @@ struct GEMVTKernel {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
@@ -292,13 +310,17 @@ struct GEMVTKernel {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < TN; j++) {
|
||||
out_vec[out_col + j] = result[j];
|
||||
|
||||
if(kDoAxpby) {
|
||||
out_vec[out_col + j] =
|
||||
static_cast<T>(alpha) * result[j] +
|
||||
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_col + j] = result[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -310,78 +332,64 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides_vec [[buffer(7)]],
|
||||
const device size_t* nc_strides_mat [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
|
||||
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
|
||||
if(kDoNCBatch) {
|
||||
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
||||
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
||||
}
|
||||
|
||||
} else {
|
||||
in_vec += tid.z * vector_batch_stride[0];
|
||||
mat += tid.z * matrix_batch_stride[0];
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += tid.z * bias_batch_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
@@ -392,41 +400,34 @@ template <
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
|
||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides_vec [[buffer(7)]], \
|
||||
const device size_t* nc_strides_mat [[buffer(8)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
@@ -446,77 +447,64 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides_vec [[buffer(7)]],
|
||||
const device size_t* nc_strides_mat [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
|
||||
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
|
||||
if(kDoNCBatch) {
|
||||
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
||||
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
||||
}
|
||||
|
||||
} else {
|
||||
in_vec += tid.z * vector_batch_stride[0];
|
||||
mat += tid.z * matrix_batch_stride[0];
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += tid.z * bias_batch_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
@@ -526,41 +514,34 @@ template <
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
|
||||
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides_vec [[buffer(7)]], \
|
||||
const device size_t* nc_strides_mat [[buffer(8)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
|
||||
#define instantiate_gemv_t_blocks(name, itype) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
|
553
mlx/backend/metal/kernels/layer_norm.metal
Normal file
553
mlx/backend/metal/kernels/layer_norm.metal
Normal file
@@ -0,0 +1,553 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Layer normalization where one threadgroup handles one row: threadgroup gid
// processes the axis_size elements starting at row gid, and thread lid handles
// the N_READS consecutive elements starting at lid * N_READS. The mean and
// variance are derived from the row's sum and sum of squares, and the output
// is w * (x - mean) * rsqrt(variance + eps) + b, with w and b read using the
// element strides w_stride and b_stride.
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_single_row(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;
  // Values read by this thread, kept for the output pass.
  float thread_x[N_READS];

  constexpr int SIMD_SIZE = 32;

  // One partial sum per simdgroup, plus the broadcast slots for the result.
  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  // gid and axis_size are both 32-bit, so their product would wrap for
  // buffers holding 2^32 or more elements; compute the row offset in 64 bits.
  const size_t row_offset = static_cast<size_t>(gid) * axis_size + lid * N_READS;

  x += row_offset;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  if (lid * N_READS + N_READS <= axis_size) {
    // Fast path: all N_READS elements are inside the row.
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      sumx2 += thread_x[i] * thread_x[i];
      sumx += thread_x[i];
    }
  } else {
    // Tail of the row: skip elements past axis_size.
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        sumx2 += thread_x[i] * thread_x[i];
        sumx += thread_x[i];
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      // variance = E[x^2] - E[x]^2
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Every thread picks up the row statistics published by lane 0.
  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += row_offset;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
        out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
      }
    }
  }
}
|
||||
|
||||
/// Layer normalization of one row per threadgroup for rows that need several
/// sweeps of the threadgroup.
///
/// Threadgroup `gid` handles the row starting at `x + gid * axis_size`. In
/// each sweep thread `lid` visits N_READS consecutive elements starting at
/// `r + lid * N_READS`, where `r` advances by `lsize * N_READS`. The input is
/// read twice from device memory: once to accumulate the statistics and once
/// to produce the output.
///
/// Statistics are accumulated in float regardless of T. The result written is
/// `w * T((x - mean) * rsqrt(var + eps)) + b`, with the affine part evaluated
/// in T.
///
/// @param x          input rows, `axis_size` contiguous elements each
/// @param w          scale, element j read at `w[w_stride * j]`
/// @param b          shift, element j read at `b[b_stride * j]`
/// @param out        output rows, same layout as `x`
/// @param eps        added to the variance before the reciprocal square root
/// @param axis_size  number of elements in a row
/// @param w_stride   element stride of `w`
/// @param b_stride   element stride of `b`
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_looped(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;

  constexpr int SIMD_SIZE = 32;

  // One slot per simdgroup for the cross-simdgroup reduction.
  // NOTE(review): indexing by simd_group_id assumes at most SIMD_SIZE
  // simdgroups per threadgroup.
  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  // Move every pointer to this thread's first element of the first sweep.
  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  // First pass: accumulate sum(x) and sum(x^2) over all sweeps.
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      // Whole slice is inside the row: no per-element bounds check needed.
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        sumx2 += xi * xi;
        sumx += xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          sumx2 += xi * xi;
          sumx += xi;
        }
      }
    }
  }

  // Reduce within each simdgroup.
  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;
      // E[x^2] - E[x]^2 can round slightly below zero when the mean is large
      // relative to the spread; a negative argument would make rsqrt return
      // NaN. The comparison is false for NaN, so NaN inputs still propagate.
      variance = variance < 0 ? 0 : variance;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Broadcast the row statistics to every thread.
  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = (x[r + i] - mean) * normalizer;
        out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = (x[r + i] - mean) * normalizer;
          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
        }
      }
    }
  }
}
|
||||
|
||||
/// Backward pass of layer normalization, one row per threadgroup, single
/// sweep.
///
/// Threadgroup `gid` handles row `gid` of `x` and `g`. Thread `lid` owns the
/// N_READS consecutive elements starting at `lid * N_READS` and keeps x, w and
/// g in registers between the statistics pass and the output pass. Elements
/// beyond `threads * N_READS` are never visited, so presumably the host only
/// selects this kernel when a row fits in one sweep (NOTE(review): confirm).
///
/// With xhat = (x - mean) * normalizer and normalizer = rsqrt(var + eps):
///   gx = normalizer * (w * g - mean(w * g))
///        - xhat * (mean(w * g * x) - mean(w * g) * mean) * normalizer^2
///   gw = g * xhat
/// `gw` is written per row at `gid * axis_size`; any reduction over rows is
/// not done here.
///
/// @param x          forward input rows, `axis_size` contiguous elements each
/// @param w          scale, element j read at `w[w_stride * j]`
/// @param g          incoming gradient, same layout as `x`
/// @param gx         gradient with respect to `x`, same layout as `x`
/// @param gw         per-row, per-element `g * xhat`, same layout as `x`
/// @param eps        added to the variance before the reciprocal square root
/// @param axis_size  number of elements in a row
/// @param w_stride   element stride of `w`
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_single_row(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the computation and accumulators
  float thread_x[N_READS];
  float thread_w[N_READS];
  float thread_g[N_READS];
  float sumx = 0;
  float sumx2 = 0;
  float sumwg = 0;
  float sumwgx = 0;

  constexpr int SIMD_SIZE = 32;

  // One slot per simdgroup for the cross-simdgroup reduction.
  // NOTE(review): indexing by simd_group_id assumes at most SIMD_SIZE
  // simdgroups per threadgroup.
  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumwg[SIMD_SIZE];
  threadgroup float local_sumwgx[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];
  threadgroup float local_meanwg[1];
  threadgroup float local_meanwgx[1];

  // Load the thread's slice and accumulate the four sums.
  if (lid * N_READS + N_READS <= axis_size) {
    // Whole slice is inside the row: no per-element bounds check needed.
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      thread_w[i] = w[i * w_stride];
      thread_g[i] = g[i];
      float wg = thread_w[i] * thread_g[i];
      sumx += thread_x[i];
      sumx2 += thread_x[i] * thread_x[i];
      sumwg += wg;
      sumwgx += wg * thread_x[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        thread_w[i] = w[i * w_stride];
        thread_g[i] = g[i];
        float wg = thread_w[i] * thread_g[i];
        sumx += thread_x[i];
        sumx2 += thread_x[i] * thread_x[i];
        sumwg += wg;
        sumwgx += wg * thread_x[i];
      }
    }
  }

  // Reduce within each simdgroup.
  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);
  sumwg = simd_sum(sumwg);
  sumwgx = simd_sum(sumwgx);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
    local_sumwg[simd_lane_id] = 0;
    local_sumwgx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
    local_sumwg[simd_group_id] = sumwg;
    local_sumwgx[simd_group_id] = sumwgx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumwg = simd_sum(local_sumwg[simd_lane_id]);
    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;
      // E[x^2] - E[x]^2 can round slightly below zero when the mean is large
      // relative to the spread; a negative argument would make rsqrt return
      // NaN. The comparison is false for NaN, so NaN inputs still propagate.
      variance = variance < 0 ? 0 : variance;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
      local_meanwg[0] = sumwg / axis_size;
      local_meanwgx[0] = sumwgx / axis_size;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Broadcast the row statistics to every thread.
  float mean = local_mean[0];
  float normalizer = local_normalizer[0];
  float meanwg = local_meanwg[0];
  // mean(w*g*x) - mean(w*g) * mean(x)
  float meanwgxc = local_meanwgx[0] - meanwg * mean;
  float normalizer2 = normalizer * normalizer;

  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      // thread_x now holds the normalized value xhat.
      thread_x[i] = (thread_x[i] - mean) * normalizer;
      gx[i] = static_cast<T>(
          normalizer * (thread_w[i] * thread_g[i] - meanwg) -
          thread_x[i] * meanwgxc * normalizer2);
      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
        gx[i] = static_cast<T>(
            normalizer * (thread_w[i] * thread_g[i] - meanwg) -
            thread_x[i] * meanwgxc * normalizer2);
        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
      }
    }
  }
}
|
||||
|
||||
/// Backward pass of layer normalization, one row per threadgroup, for rows
/// that need several sweeps of the threadgroup.
///
/// Threadgroup `gid` handles row `gid` of `x` and `g`. In each sweep thread
/// `lid` visits N_READS consecutive elements starting at `r + lid * N_READS`,
/// where `r` advances by `lsize * N_READS`. x, w and g are read twice from
/// device memory: once for the statistics and once for the outputs.
///
/// With xhat = (x - mean) * normalizer and normalizer = rsqrt(var + eps):
///   gx = normalizer * (w * g - mean(w * g))
///        - xhat * (mean(w * g * x) - mean(w * g) * mean) * normalizer^2
///   gw = g * xhat
/// `gw` is written per row at `gid * axis_size`; any reduction over rows is
/// not done here.
///
/// @param x          forward input rows, `axis_size` contiguous elements each
/// @param w          scale, element j read at `w[w_stride * j]`
/// @param g          incoming gradient, same layout as `x`
/// @param gx         gradient with respect to `x`, same layout as `x`
/// @param gw         per-row, per-element `g * xhat`, same layout as `x`
/// @param eps        added to the variance before the reciprocal square root
/// @param axis_size  number of elements in a row
/// @param w_stride   element stride of `w`
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_looped(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the accumulators
  float sumx = 0;
  float sumx2 = 0;
  float sumwg = 0;
  float sumwgx = 0;

  constexpr int SIMD_SIZE = 32;

  // One slot per simdgroup for the cross-simdgroup reduction.
  // NOTE(review): indexing by simd_group_id assumes at most SIMD_SIZE
  // simdgroups per threadgroup.
  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumwg[SIMD_SIZE];
  threadgroup float local_sumwgx[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];
  threadgroup float local_meanwg[1];
  threadgroup float local_meanwgx[1];

  // First pass: accumulate the four sums over all sweeps.
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      // Whole slice is inside the row: no per-element bounds check needed.
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        float wi = w[(i + r) * w_stride];
        float gi = g[i + r];
        float wg = wi * gi;
        sumx += xi;
        sumx2 += xi * xi;
        sumwg += wg;
        sumwgx += wg * xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          float wi = w[(i + r) * w_stride];
          float gi = g[i + r];
          float wg = wi * gi;
          sumx += xi;
          sumx2 += xi * xi;
          sumwg += wg;
          sumwgx += wg * xi;
        }
      }
    }
  }

  // Reduce within each simdgroup.
  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);
  sumwg = simd_sum(sumwg);
  sumwgx = simd_sum(sumwgx);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
    local_sumwg[simd_lane_id] = 0;
    local_sumwgx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
    local_sumwg[simd_group_id] = sumwg;
    local_sumwgx[simd_group_id] = sumwgx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumwg = simd_sum(local_sumwg[simd_lane_id]);
    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;
      // E[x^2] - E[x]^2 can round slightly below zero when the mean is large
      // relative to the spread; a negative argument would make rsqrt return
      // NaN. The comparison is false for NaN, so NaN inputs still propagate.
      variance = variance < 0 ? 0 : variance;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
      local_meanwg[0] = sumwg / axis_size;
      local_meanwgx[0] = sumwgx / axis_size;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Broadcast the row statistics to every thread.
  float mean = local_mean[0];
  float normalizer = local_normalizer[0];
  float meanwg = local_meanwg[0];
  // mean(w*g*x) - mean(w*g) * mean(x)
  float meanwgxc = local_meanwgx[0] - meanwg * mean;
  float normalizer2 = normalizer * normalizer;

  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        // xi is the normalized value xhat.
        float xi = (x[i + r] - mean) * normalizer;
        float wi = w[(i + r) * w_stride];
        float gi = g[i + r];
        gx[i + r] = static_cast<T>(
            normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
        gw[i + r] = static_cast<T>(gi * xi);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = (x[i + r] - mean) * normalizer;
          float wi = w[(i + r) * w_stride];
          float gi = g[i + r];
          gx[i + r] = static_cast<T>(
              normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
          gw[i + r] = static_cast<T>(gi * xi);
        }
      }
    }
  }
}
|
||||
|
||||
// clang-format off
|
||||
// Explicit instantiations of the single-row forward and backward kernels for
// one element type. Host-visible names are "layer_norm<name>" and
// "vjp_layer_norm<name>". Parameter names and attributes match the template
// definitions (gid is the threadgroup position, not the thread position).
#define instantiate_layer_norm_single_row(name, itype) \
  template [[host_name("layer_norm" #name)]] [[kernel]] void \
  layer_norm_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* b, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      constant uint& b_stride, \
      uint gid [[threadgroup_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
  vjp_layer_norm_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[threadgroup_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// Explicit instantiations of the looped forward and backward kernels for one
// element type. Host-visible names are "layer_norm_looped<name>" and
// "vjp_layer_norm_looped<name>".
#define instantiate_layer_norm_looped(name, itype) \
  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
  layer_norm_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* b, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      constant uint& b_stride, \
      uint gid [[threadgroup_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
  vjp_layer_norm_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[threadgroup_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

// Instantiate all four kernels (single-row and looped, forward and vjp) for
// one element type.
#define instantiate_layer_norm(name, itype) \
  instantiate_layer_norm_single_row(name, itype) \
  instantiate_layer_norm_looped(name, itype)

instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
|
@@ -15,14 +15,6 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
// Maps an element type to the type used for accumulation: T itself by
// default, float for bfloat16_t.
template <typename T>
struct AccT {
  using acc_t = T;
};

template <>
struct AccT<bfloat16_t> {
  using acc_t = float;
};
|
||||
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
@@ -60,6 +52,51 @@ inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
return sum;
|
||||
}
|
||||
|
||||
// Bounds-limited load of a thread's slice of x into x_thread.
//
// Copies the first N values of x into x_thread and zero-fills entries
// N..values_per_thread-1. For 2- and 4-bit weights the values are processed in
// groups of four, and the 2nd/3rd/4th value of each group is divided by
// 2^bits, 2^(2*bits) and 2^(3*bits) respectively, which lets the dot product
// use masked but unshifted packed weights. Returns the sum of the values read.
//
// NOTE(review): for bits 2 and 4 a whole group of four is read even when N is
// not a multiple of 4 — callers presumably always pass a multiple of 4.
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  U total = 0;

  if (bits == 8) {
    // One value per byte: plain copy.
    for (int idx = 0; idx < N; ++idx) {
      total += x[idx];
      x_thread[idx] = x[idx];
    }
  } else {
    // bits is 2 or 4 here (guaranteed by the static_assert).
    constexpr float div1 = (bits == 2) ? 4.0f : 16.0f;
    constexpr float div2 = (bits == 2) ? 16.0f : 256.0f;
    constexpr float div3 = (bits == 2) ? 64.0f : 4096.0f;

    for (int base = 0; base < N; base += 4) {
      const device T* src = x + base;
      thread U* dst = x_thread + base;

      total += src[0] + src[1] + src[2] + src[3];
      dst[0] = src[0];
      dst[1] = src[1] / div1;
      dst[2] = src[2] / div2;
      dst[3] = src[3] / div3;
    }
  }

  // Zero the unused tail so later products with it contribute nothing.
  for (int idx = N; idx < values_per_thread; ++idx) {
    x_thread[idx] = 0;
  }

  return total;
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
@@ -96,6 +133,74 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
|
||||
return scale * accum + sum * bias;
|
||||
}
|
||||
|
||||
// Dot product of the first N entries of x_thread with packed quantized
// weights, returning scale * dot + sum * bias.
//
// For 2- and 4-bit weights the packed values are masked but not shifted down,
// so the x_thread entries are expected to be pre-divided by the matching power
// of two (as load_vector_safe does). Only N / 4 whole groups of four are
// consumed in those modes.
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  U acc = 0;

  if (bits == 2) {
    // Four 2-bit values per byte.
    const int n_groups = N / 4;
    for (int q = 0; q < n_groups; ++q) {
      const uint8_t packed = w[q];
      const thread U* xq = x_thread + 4 * q;
      acc += (
          xq[0] * (packed & 0x03)
          + xq[1] * (packed & 0x0c)
          + xq[2] * (packed & 0x30)
          + xq[3] * (packed & 0xc0));
    }
  } else if (bits == 4) {
    // Four 4-bit values per 16-bit word.
    const device uint16_t* w16 = (const device uint16_t*)w;
    const int n_groups = N / 4;
    for (int q = 0; q < n_groups; ++q) {
      const uint16_t packed = w16[q];
      const thread U* xq = x_thread + 4 * q;
      acc += (
          xq[0] * (packed & 0x000f)
          + xq[1] * (packed & 0x00f0)
          + xq[2] * (packed & 0x0f00)
          + xq[3] * (packed & 0xf000));
    }
  } else if (bits == 8) {
    // One value per byte.
    for (int j = 0; j < N; ++j) {
      acc += x_thread[j] * w[j];
    }
  }

  return scale * acc + sum * bias;
}
|
||||
|
||||
// Accumulates x * dequantized(w) into result for values_per_thread packed
// weights: result[j] += x * (scale_j * masked_w_j + bias).
//
// For 2- and 4-bit weights the packed values are masked but not shifted down;
// instead the scale is pre-divided by the matching power of two for each
// position within a group of four.
template <typename U, int values_per_thread, int bits>
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    // Four 2-bit values per byte.
    const U s0 = scale;
    const U s1 = scale / 4.0f;
    const U s2 = scale / 16.0f;
    const U s3 = scale / 64.0f;
    constexpr int n_groups = values_per_thread / 4;
    for (int q = 0; q < n_groups; ++q) {
      const uint8_t packed = w[q];
      thread U* out = result + 4 * q;
      out[0] += x * (s0 * (packed & 0x03) + bias);
      out[1] += x * (s1 * (packed & 0x0c) + bias);
      out[2] += x * (s2 * (packed & 0x30) + bias);
      out[3] += x * (s3 * (packed & 0xc0) + bias);
    }
  } else if (bits == 4) {
    // Four 4-bit values per 16-bit word.
    const thread uint16_t* w16 = (const thread uint16_t*)w;
    const U s0 = scale;
    const U s1 = scale / 16.0f;
    const U s2 = scale / 256.0f;
    const U s3 = scale / 4096.0f;
    constexpr int n_groups = values_per_thread / 4;
    for (int q = 0; q < n_groups; ++q) {
      const uint16_t packed = w16[q];
      thread U* out = result + 4 * q;
      out[0] += x * (s0 * (packed & 0x000f) + bias);
      out[1] += x * (s1 * (packed & 0x00f0) + bias);
      out[2] += x * (s2 * (packed & 0x0f00) + bias);
      out[3] += x * (s3 * (packed & 0xf000) + bias);
    }
  } else if (bits == 8) {
    // One value per byte.
    for (int j = 0; j < values_per_thread; ++j) {
      result[j] += x * (scale * w[j] + bias);
    }
  }
}
|
||||
|
||||
template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
[[kernel]] void qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
@@ -204,7 +309,8 @@ template <typename T, const int group_size, const int bits>
|
||||
x += tid.z * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.z * out_vec_size + out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
@@ -222,6 +328,18 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
@@ -239,7 +357,8 @@ template <typename T, const int group_size, const int bits>
|
||||
x += tid.z * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.z * out_vec_size + used_out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
@@ -257,6 +376,18 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
@@ -268,7 +399,7 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void qvm(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
@@ -278,39 +409,28 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
|
||||
static_assert(BN == BM, "qvm expects a block size of 32x32");
|
||||
constexpr int num_simdgroups = 8;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int blocksize = SIMD_SIZE;
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_int = 32 / bits;
|
||||
constexpr int colgroup = BN * el_per_int;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[BM];
|
||||
typedef float U;
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread U result[el_per_int] = {0};
|
||||
thread U result[pack_factor] = {0};
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_w = out_vec_size / pack_factor;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col_start = tid.y * (BN * el_per_int);
|
||||
int out_col = out_col_start + simd_gid * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col_start / group_size;
|
||||
biases += out_col_start / group_size;
|
||||
int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
|
||||
w += out_col / pack_factor;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
@@ -318,53 +438,39 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
return;
|
||||
}
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
int offset_lid = simd_lid + i;
|
||||
int offset_gid = simd_gid + i;
|
||||
bool thread_in_bounds = offset_lid < in_vec_size;
|
||||
bool group_in_bounds = offset_gid < in_vec_size;
|
||||
// Loop over in_vec in blocks of blocksize
|
||||
int i = 0;
|
||||
for (; i + blocksize <= in_vec_size; i += blocksize) {
|
||||
x_local = x[i + simd_lid];
|
||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lid < groups_per_block && group_in_bounds) {
|
||||
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
|
||||
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load in_vec, scale, bias to registers
|
||||
x_local = x_block[simd_lid];
|
||||
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
|
||||
w_local >>= bits;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
}
|
||||
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
||||
x_local = x[i + simd_lid];
|
||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
} else {
|
||||
x_local = 0;
|
||||
scale = 0;
|
||||
bias = 0;
|
||||
w_local = 0;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
|
||||
// Accumulate in the simdgroup
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
result[k] = simd_sum(result[k]);
|
||||
}
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
y[k] = static_cast<T>(result[k]);
|
||||
}
|
||||
}
|
||||
@@ -414,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const int K_g = K / group_size;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
|
||||
x += y_row * K;
|
||||
w += y_col * K_w;
|
||||
scales += y_col * K_g;
|
||||
@@ -466,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
||||
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
||||
|
||||
if (y_row + offset_row < N) {
|
||||
// y_col corresponds to the row of the weight matrix and added to
|
||||
// offset_row it should be less than the total number of rows
|
||||
// otherwise skip.
|
||||
if (y_col + offset_row < N) {
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
@@ -619,7 +729,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
|
||||
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
|
||||
|
||||
if (y_row + offset_row < K) {
|
||||
if (k + offset_row < K) {
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
@@ -738,7 +848,7 @@ instantiate_qmv_types( 32, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
|
||||
[[kernel]] void qvm<itype, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
@@ -747,7 +857,6 @@ instantiate_qmv_types( 32, 8)
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
|
@@ -6,6 +6,69 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Small column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column reduction that uses a single thread per output element.
//
// Thread `tid` offsets `in` to its own output element using the trailing
// `ndim - non_col_ndim` entries of `shape` / `strides`. It then visits
// `non_col_reductions` start locations (resolved through `non_col_shapes` /
// `non_col_strides`) and, from each of them, folds `reduction_size` elements
// spaced `reduction_stride` apart into one accumulator with `Op`.
// The accumulated value is stored in `out[tid]`.
template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]) {
  // out_size is not used by this kernel; the cast silences the
  // unused-parameter warning.
  (void)out_size;

  Op reduce;
  U accum = Op::init;

  // Start of the data that belongs to this thread's output element.
  const device T* base = in +
      elem_to_loc(tid,
                  shape + non_col_ndim,
                  strides + non_col_ndim,
                  ndim - non_col_ndim);

  for (uint outer = 0; outer < non_col_reductions; ++outer) {
    // Location of the first element of this strided run.
    size_t offset =
        elem_to_loc(outer, non_col_shapes, non_col_strides, non_col_ndim);

    for (uint inner = 0; inner < reduction_size; ++inner) {
      accum = reduce(accum, static_cast<U>(base[offset]));
      offset += reduction_stride;
    }
  }

  out[tid] = accum;
}
|
||||
|
||||
// Explicitly instantiates col_reduce_small<itype, otype, op> and exports it
// to the host under the kernel name "col_reduce_small_<name>". The argument
// list repeats the signature of the col_reduce_small template.
#define instantiate_col_reduce_small(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_small_" #name)]] \
|
||||
[[kernel]] void col_reduce_small<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
const constant size_t& non_col_reductions [[buffer(8)]], \
|
||||
const constant int* non_col_shapes [[buffer(9)]], \
|
||||
const constant size_t* non_col_strides [[buffer(10)]], \
|
||||
const constant int& non_col_ndim [[buffer(11)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -171,9 +234,11 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
|
||||
|
||||
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
|
||||
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
|
||||
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
|
||||
|
||||
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
|
||||
@@ -181,4 +246,8 @@ instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce
|
||||
|
||||
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
|
||||
|
||||
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
|
||||
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
|
@@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
|
||||
const short i_ed = short(reduction_size);
|
||||
const short i_jump = reductions_per_thread;
|
||||
|
||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
||||
if(r_st < r_jump) {
|
||||
for(short r = r_st; r < r_ed; r += r_jump) {
|
||||
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T * in_row = in + in_idx;
|
||||
|
||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
|
||||
for(short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
435
mlx/backend/metal/kernels/rms_norm.metal
Normal file
435
mlx/backend/metal/kernels/rms_norm.metal
Normal file
@@ -0,0 +1,435 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// RMS normalization of one row per threadgroup, for rows that fit in a single
// pass of the threadgroup.
//
// Threadgroup `gid` processes the row starting at `x + gid * axis_size`; each
// thread owns N_READS consecutive elements of it. The kernel computes
//   inv_rms = rsqrt(sum(x^2) / axis_size + eps)
// with the sum accumulated in float, and writes
//   out[i] = w[i * w_stride] * T(x[i] * inv_rms).
// `local_sums` holds one partial sum per simdgroup and `local_inv_mean[0]`
// broadcasts inv_rms to every thread of the group.
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_single_row(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Index, inside the row, of the first element owned by this thread.
  const uint first = lid * N_READS;
  // True when all N_READS elements of this thread lie inside the row, so the
  // per-element bounds check can be skipped.
  const bool in_bounds = first + N_READS <= axis_size;

  x += gid * axis_size + first;
  w += w_stride * first;

  // Per-thread sum of squares.
  float sum_sq = 0;
  if (in_bounds) {
    for (int k = 0; k < N_READS; ++k) {
      const float v = x[k];
      sum_sq += v * v;
    }
  } else {
    for (int k = 0; k < N_READS; ++k) {
      if (first + k < axis_size) {
        const float v = x[k];
        sum_sq += v * v;
      }
    }
  }

  // Reduce inside the simdgroup.
  sum_sq = simd_sum(sum_sq);

  // Clear the per-simdgroup slots before they are filled below.
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Lane 0 of every simdgroup publishes its partial sum.
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = sum_sq;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first simdgroup combines the partial sums and stores the normalizer.
  if (simd_group_id == 0) {
    const float row_sum = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(row_sum / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Scale the inputs and write the outputs.
  const float inv_rms = local_inv_mean[0];
  out += gid * axis_size + first;
  if (in_bounds) {
    for (int k = 0; k < N_READS; ++k) {
      out[k] = w[w_stride * k] * static_cast<T>(x[k] * inv_rms);
    }
  } else {
    for (int k = 0; k < N_READS; ++k) {
      if (first + k < axis_size) {
        out[k] = w[w_stride * k] * static_cast<T>(x[k] * inv_rms);
      }
    }
  }
}
|
||||
|
||||
// RMS normalization of one row per threadgroup, where the threadgroup sweeps
// the row in tiles of `lsize * N_READS` elements.
//
// Threadgroup `gid` processes the row starting at `x + gid * axis_size`.
// Within each tile a thread owns N_READS consecutive elements. The kernel
// computes
//   inv_rms = rsqrt(sum(x^2) / axis_size + eps)
// with the sum accumulated in float, and then makes a second sweep writing
//   out[i] = w[i * w_stride] * T(x[i] * inv_rms).
// `local_sums` holds one partial sum per simdgroup and `local_inv_mean[0]`
// broadcasts inv_rms to every thread of the group.
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_looped(
    const device T* x,
    const device T* w,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Offset of this thread inside a tile, and the tile size.
  const uint first = lid * N_READS;
  const uint tile = lsize * N_READS;

  x += gid * axis_size + first;
  w += w_stride * first;

  // First sweep: per-thread sum of squares.
  float sum_sq = 0;
  for (uint base = 0; base < axis_size; base += tile) {
    // Row index of this thread's first element in the current tile.
    const uint pos = base + first;
    if (pos + N_READS <= axis_size) {
      for (int k = 0; k < N_READS; ++k) {
        const float v = x[k + base];
        sum_sq += v * v;
      }
    } else {
      for (int k = 0; k < N_READS; ++k) {
        if (pos + k < axis_size) {
          const float v = x[k + base];
          sum_sq += v * v;
        }
      }
    }
  }

  // Reduce inside the simdgroup.
  sum_sq = simd_sum(sum_sq);

  // Clear the per-simdgroup slots before they are filled below.
  if (simd_group_id == 0) {
    local_sums[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Lane 0 of every simdgroup publishes its partial sum.
  if (simd_lane_id == 0) {
    local_sums[simd_group_id] = sum_sq;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first simdgroup combines the partial sums and stores the normalizer.
  if (simd_group_id == 0) {
    const float row_sum = simd_sum(local_sums[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_inv_mean[0] = metal::precise::rsqrt(row_sum / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Second sweep: scale the inputs and write the outputs.
  const float inv_rms = local_inv_mean[0];
  out += gid * axis_size + first;
  for (uint base = 0; base < axis_size; base += tile) {
    const uint pos = base + first;
    if (pos + N_READS <= axis_size) {
      for (int k = 0; k < N_READS; ++k) {
        const uint idx = base + k;
        out[idx] = w[w_stride * idx] * static_cast<T>(x[idx] * inv_rms);
      }
    } else {
      for (int k = 0; k < N_READS; ++k) {
        if (pos + k < axis_size) {
          const uint idx = base + k;
          out[idx] = w[w_stride * idx] * static_cast<T>(x[idx] * inv_rms);
        }
      }
    }
  }
}
|
||||
|
||||
// Backward pass of RMS normalization for one row per threadgroup, for rows
// that fit in a single pass of the threadgroup.
//
// With n = rsqrt(sum(x^2) / axis_size + eps) and
// m = sum(x * w * g) / axis_size, the kernel writes for every element i of
// row `gid`:
//   gx[i] = T(g[i] * w[i * w_stride] * n - x[i] * m * n^3)
//   gw[i] = T(g[i] * x[i] * n)
// Both outputs are indexed per row (`gid * axis_size + i`). Each thread keeps
// its N_READS inputs in float registers so they are read only once.
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_single_row(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Index, inside the row, of the first element owned by this thread.
  const uint first = lid * N_READS;
  // True when all N_READS elements of this thread lie inside the row.
  const bool in_bounds = first + N_READS <= axis_size;
  // Offset of this thread's first element in the row-major buffers.
  const uint row_offset = gid * axis_size + first;

  // Move the input pointers to this thread's elements.
  x += row_offset;
  g += row_offset;
  w += w_stride * first;

  // Register copies of the inputs and the two per-thread accumulators.
  float xs[N_READS];
  float ws[N_READS];
  float gs[N_READS];
  float sum_sq = 0;
  float sum_gwx = 0;

  // Threadgroup scratch space for the two-level reduction.
  constexpr int SIMD_SIZE = 32;
  threadgroup float partial_sq[SIMD_SIZE];
  threadgroup float partial_gwx[SIMD_SIZE];
  threadgroup float shared_normalizer[1];
  threadgroup float shared_mean_gwx[1];

  // Load the inputs and accumulate locally.
  if (in_bounds) {
    for (int k = 0; k < N_READS; ++k) {
      xs[k] = x[k];
      ws[k] = w[w_stride * k];
      gs[k] = g[k];

      sum_sq += xs[k] * xs[k];
      sum_gwx += xs[k] * ws[k] * gs[k];
    }
  } else {
    for (int k = 0; k < N_READS; ++k) {
      if (first + k < axis_size) {
        xs[k] = x[k];
        ws[k] = w[w_stride * k];
        gs[k] = g[k];

        sum_sq += xs[k] * xs[k];
        sum_gwx += xs[k] * ws[k] * gs[k];
      }
    }
  }

  // Reduce inside the simdgroup.
  sum_sq = simd_sum(sum_sq);
  sum_gwx = simd_sum(sum_gwx);

  // Clear the per-simdgroup slots before they are filled below.
  if (simd_group_id == 0) {
    partial_sq[simd_lane_id] = 0;
    partial_gwx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Lane 0 of every simdgroup publishes its partial sums.
  if (simd_lane_id == 0) {
    partial_sq[simd_group_id] = sum_sq;
    partial_gwx[simd_group_id] = sum_gwx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first simdgroup combines the partial sums and stores the results.
  if (simd_group_id == 0) {
    sum_sq = simd_sum(partial_sq[simd_lane_id]);
    sum_gwx = simd_sum(partial_gwx[simd_lane_id]);
    if (simd_lane_id == 0) {
      shared_mean_gwx[0] = sum_gwx / axis_size;
      shared_normalizer[0] = metal::precise::rsqrt(sum_sq / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  const float mean_gwx = shared_mean_gwx[0];
  const float norm = shared_normalizer[0];
  const float norm_cubed = norm * norm * norm;

  // Write the gradients.
  gx += row_offset;
  gw += row_offset;
  if (in_bounds) {
    for (int k = 0; k < N_READS; ++k) {
      gx[k] = static_cast<T>(
          gs[k] * ws[k] * norm - xs[k] * mean_gwx * norm_cubed);
      gw[k] = static_cast<T>(gs[k] * xs[k] * norm);
    }
  } else {
    for (int k = 0; k < N_READS; ++k) {
      if (first + k < axis_size) {
        gx[k] = static_cast<T>(
            gs[k] * ws[k] * norm - xs[k] * mean_gwx * norm_cubed);
        gw[k] = static_cast<T>(gs[k] * xs[k] * norm);
      }
    }
  }
}
|
||||
|
||||
// Backward pass of RMS normalization for one row per threadgroup, where the
// threadgroup sweeps the row in tiles of `lsize * N_READS` elements.
//
// With n = rsqrt(sum(x^2) / axis_size + eps) and
// m = sum(x * w * g) / axis_size, the kernel writes for every element i of
// row `gid`:
//   gx[i] = T(g[i] * w[i * w_stride] * n - x[i] * m * n^3)
//   gw[i] = T(g[i] * x[i] * n)
// Both outputs are indexed per row (`gid * axis_size + i`). The inputs are
// read twice: once to accumulate the sums and once to produce the gradients.
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_looped(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Offset of this thread inside a tile, and the tile size.
  const uint first = lid * N_READS;
  const uint tile = lsize * N_READS;
  // Offset of this thread's first element in the row-major buffers.
  const uint row_offset = gid * axis_size + first;

  // Move the input pointers to this thread's elements.
  x += row_offset;
  g += row_offset;
  w += w_stride * first;

  // Per-thread accumulators.
  float sum_sq = 0;
  float sum_gwx = 0;

  // Threadgroup scratch space for the two-level reduction.
  constexpr int SIMD_SIZE = 32;
  threadgroup float partial_sq[SIMD_SIZE];
  threadgroup float partial_gwx[SIMD_SIZE];
  threadgroup float shared_normalizer[1];
  threadgroup float shared_mean_gwx[1];

  // First sweep: accumulate sum(x^2) and sum(x * w * g).
  for (uint base = 0; base < axis_size; base += tile) {
    // Row index of this thread's first element in the current tile.
    const uint pos = base + first;
    if (pos + N_READS <= axis_size) {
      for (int k = 0; k < N_READS; ++k) {
        const float xv = x[k + base];
        const float wv = w[w_stride * (k + base)];
        const float gv = g[k + base];

        sum_sq += xv * xv;
        sum_gwx += xv * wv * gv;
      }
    } else {
      for (int k = 0; k < N_READS; ++k) {
        if (pos + k < axis_size) {
          const float xv = x[k + base];
          const float wv = w[w_stride * (k + base)];
          const float gv = g[k + base];

          sum_sq += xv * xv;
          sum_gwx += xv * wv * gv;
        }
      }
    }
  }

  // Reduce inside the simdgroup.
  sum_sq = simd_sum(sum_sq);
  sum_gwx = simd_sum(sum_gwx);

  // Clear the per-simdgroup slots before they are filled below.
  if (simd_group_id == 0) {
    partial_sq[simd_lane_id] = 0;
    partial_gwx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Lane 0 of every simdgroup publishes its partial sums.
  if (simd_lane_id == 0) {
    partial_sq[simd_group_id] = sum_sq;
    partial_gwx[simd_group_id] = sum_gwx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The first simdgroup combines the partial sums and stores the results.
  if (simd_group_id == 0) {
    sum_sq = simd_sum(partial_sq[simd_lane_id]);
    sum_gwx = simd_sum(partial_gwx[simd_lane_id]);
    if (simd_lane_id == 0) {
      shared_mean_gwx[0] = sum_gwx / axis_size;
      shared_normalizer[0] = metal::precise::rsqrt(sum_sq / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  const float mean_gwx = shared_mean_gwx[0];
  const float norm = shared_normalizer[0];
  const float norm_cubed = norm * norm * norm;

  // Second sweep: re-read the inputs and write the gradients.
  gx += row_offset;
  gw += row_offset;
  for (uint base = 0; base < axis_size; base += tile) {
    const uint pos = base + first;
    if (pos + N_READS <= axis_size) {
      for (int k = 0; k < N_READS; ++k) {
        const float xv = x[k + base];
        const float wv = w[w_stride * (k + base)];
        const float gv = g[k + base];

        gx[k + base] =
            static_cast<T>(gv * wv * norm - xv * mean_gwx * norm_cubed);
        gw[k + base] = static_cast<T>(gv * xv * norm);
      }
    } else {
      for (int k = 0; k < N_READS; ++k) {
        if (pos + k < axis_size) {
          const float xv = x[k + base];
          const float wv = w[w_stride * (k + base)];
          const float gv = g[k + base];

          gx[k + base] =
              static_cast<T>(gv * wv * norm - xv * mean_gwx * norm_cubed);
          gw[k + base] = static_cast<T>(gv * xv * norm);
        }
      }
    }
  }
}
|
||||
|
||||
// clang-format off
|
||||
// Explicitly instantiates the single-row RMS norm kernels for element type
// `itype` and exports them to the host as "rms<name>" (forward) and
// "vjp_rms<name>" (backward).
//
// The argument lists repeat the signatures of the rms_single_row and
// vjp_rms_single_row templates, including the attributes: `gid` is the
// threadgroup position in the grid (one threadgroup per row), matching the
// template definitions.
#define instantiate_rms_single_row(name, itype)                        \
  template [[host_name("rms" #name)]] [[kernel]] void                  \
  rms_single_row<itype>(                                               \
      const device itype* x,                                           \
      const device itype* w,                                           \
      device itype* out,                                               \
      constant float& eps,                                             \
      constant uint& axis_size,                                        \
      constant uint& w_stride,                                         \
      threadgroup float* local_inv_mean [[threadgroup(0)]],            \
      threadgroup float* local_sums [[threadgroup(1)]],                \
      uint gid [[threadgroup_position_in_grid]],                       \
      uint lid [[thread_position_in_threadgroup]],                     \
      uint simd_lane_id [[thread_index_in_simdgroup]],                 \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);          \
                                                                       \
  template [[host_name("vjp_rms" #name)]] [[kernel]] void              \
  vjp_rms_single_row<itype>(                                           \
      const device itype* x,                                           \
      const device itype* w,                                           \
      const device itype* g,                                           \
      device itype* gx,                                                \
      device itype* gw,                                                \
      constant float& eps,                                             \
      constant uint& axis_size,                                        \
      constant uint& w_stride,                                         \
      uint gid [[threadgroup_position_in_grid]],                       \
      uint lid [[thread_position_in_threadgroup]],                     \
      uint simd_lane_id [[thread_index_in_simdgroup]],                 \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
// Explicitly instantiates the looped RMS norm kernels for element type
// `itype` and exports them to the host as "rms_looped<name>" (forward) and
// "vjp_rms_looped<name>" (backward).
//
// The argument lists repeat the signatures of the rms_looped and
// vjp_rms_looped templates, including the attributes: `gid` is the
// threadgroup position in the grid (one threadgroup per row), matching the
// template definitions.
#define instantiate_rms_looped(name, itype)                            \
  template [[host_name("rms_looped" #name)]] [[kernel]] void           \
  rms_looped<itype>(                                                   \
      const device itype* x,                                           \
      const device itype* w,                                           \
      device itype* out,                                               \
      constant float& eps,                                             \
      constant uint& axis_size,                                        \
      constant uint& w_stride,                                         \
      threadgroup float* local_inv_mean [[threadgroup(0)]],            \
      threadgroup float* local_sums [[threadgroup(1)]],                \
      uint gid [[threadgroup_position_in_grid]],                       \
      uint lid [[thread_position_in_threadgroup]],                     \
      uint lsize [[threads_per_threadgroup]],                          \
      uint simd_lane_id [[thread_index_in_simdgroup]],                 \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);          \
                                                                       \
  template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void       \
  vjp_rms_looped<itype>(                                               \
      const device itype* x,                                           \
      const device itype* w,                                           \
      const device itype* g,                                           \
      device itype* gx,                                                \
      device itype* gw,                                                \
      constant float& eps,                                             \
      constant uint& axis_size,                                        \
      constant uint& w_stride,                                         \
      uint gid [[threadgroup_position_in_grid]],                       \
      uint lid [[thread_position_in_threadgroup]],                     \
      uint lsize [[threads_per_threadgroup]],                          \
      uint simd_lane_id [[thread_index_in_simdgroup]],                 \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
// Instantiates both the single-row and the looped RMS norm kernels (forward
// and vjp) for one element type. `name` is the suffix appended to the host
// kernel names and `itype` is the Metal element type.
#define instantiate_rms(name, itype) \
|
||||
instantiate_rms_single_row(name, itype) \
|
||||
instantiate_rms_looped(name, itype)
|
||||
|
||||
// Element types for which the RMS norm kernels are instantiated.
instantiate_rms(float32, float)
|
||||
instantiate_rms(float16, half)
|
||||
instantiate_rms(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
@@ -5,11 +5,12 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional>
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
constant const float& scale,
|
||||
@@ -19,13 +20,13 @@ template <typename T, bool traditional>
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
|
||||
out_index_2 = out_index_1 + grid.x;
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
|
||||
out_index_2 = out_index_1 + grid.x * out_strides[2];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
@@ -42,27 +43,41 @@ template <typename T, bool traditional>
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1 = x1 * costheta - x2 * sintheta;
|
||||
float rx2 = x1 * sintheta + x2 * costheta;
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
rx1 = x1 * costheta - x2 * sintheta;
|
||||
rx2 = x1 * sintheta + x2 * costheta;
|
||||
} else {
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional) \
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional>( \
|
||||
[[kernel]] void rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
instantiate_rope(traditional_float16, half, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
|
||||
instantiate_rope(traditional_float32, float, true)
|
||||
instantiate_rope(float16, half, false)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false)
|
||||
instantiate_rope(float32, float, false)
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
instantiate_rope(float16, half, false, true)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false, true)
|
||||
instantiate_rope(float32, float, false, true)
|
||||
instantiate_rope(vjp_traditional_float16, half, true, false)
|
||||
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
|
||||
instantiate_rope(vjp_traditional_float32, float, true, false)
|
||||
instantiate_rope(vjp_float16, half, false, false)
|
||||
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
|
||||
instantiate_rope(vjp_float32, float, false, false)
|
||||
|
@@ -451,7 +451,7 @@ instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSu
|
||||
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
|
||||
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
|
||||
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
|
||||
instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
|
||||
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
|
||||
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
|
||||
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
|
||||
@@ -464,7 +464,7 @@ instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumP
|
||||
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
|
||||
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
|
||||
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
|
||||
instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
|
||||
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
|
||||
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
|
||||
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
|
||||
@@ -477,7 +477,7 @@ instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMa
|
||||
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
|
||||
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
|
||||
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
|
||||
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
|
||||
instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
|
||||
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
|
||||
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
|
||||
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
|
||||
@@ -490,5 +490,5 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
|
||||
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
|
||||
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
|
||||
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
|
||||
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
|
||||
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
|
||||
|
@@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
const constant bool& upd_col_contiguous [[buffer(6)]],
|
||||
const thread array<const device IdxT*, NIDX>& idx_buffers,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
@@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
out_idx += idx_val * out_strides[i];
|
||||
}
|
||||
|
||||
if (!upd_col_contiguous) {
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
} else {
|
||||
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
|
||||
}
|
||||
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
|
||||
}
|
||||
|
||||
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
|
||||
@@ -48,7 +43,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
const constant bool& upd_col_contiguous [[buffer(6)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 gid [[thread_position_in_grid]]) { \
|
||||
\
|
||||
@@ -60,7 +54,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
upd_size, \
|
||||
upd_col_contiguous, \
|
||||
idx_buffers, \
|
||||
gid); \
|
||||
\
|
||||
@@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
|
||||
const constant int* out_shape [[buffer(3)]], \
|
||||
const constant size_t* out_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_size [[buffer(5)]], \
|
||||
const constant bool& upd_col_contiguous [[buffer(6)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
@@ -12,46 +11,48 @@ using namespace metal;
|
||||
|
||||
template <typename T>
|
||||
inline T softmax_exp(T x) {
|
||||
// Softmax doesn't need high precision exponential cause it is gonna be x
|
||||
// will be in (-oo, 0] anyway and subsequently it will be divided by
|
||||
// sum(exp(x_i)).
|
||||
// Softmax doesn't need high precision exponential cause x is gonna be in
|
||||
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
|
||||
return fast::exp(x);
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_single_row(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
int lid = _lid;
|
||||
|
||||
T ld[N_READS];
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
AccT ld[N_READS];
|
||||
|
||||
in += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
ld[i] = in[i];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = AccT(in[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] =
|
||||
((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
|
||||
: Limits<AccT>::finite_min;
|
||||
}
|
||||
}
|
||||
if (simd_group_id == 0) {
|
||||
local_max[simd_lane_id] = Limits<T>::finite_min;
|
||||
local_max[simd_lane_id] = Limits<AccT>::finite_min;
|
||||
local_normalizer[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Get the max
|
||||
T maxval = Limits<T>::finite_min;
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
maxval = (maxval < ld[i]) ? ld[i] : maxval;
|
||||
}
|
||||
@@ -70,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
maxval = local_max[0];
|
||||
|
||||
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
|
||||
T normalizer = 0;
|
||||
AccT normalizer = 0;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T exp_x = softmax_exp(ld[i] - maxval);
|
||||
AccT exp_x = softmax_exp(ld[i] - maxval);
|
||||
ld[i] = exp_x;
|
||||
normalizer += exp_x;
|
||||
}
|
||||
@@ -93,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
// Normalize and write to the output
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = T(ld[i] * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = ld[i] * normalizer;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = T(ld[i] * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
|
||||
[[kernel]] void softmax_looped(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant int& axis_size,
|
||||
threadgroup T* local_max [[threadgroup(0)]],
|
||||
threadgroup T* local_normalizer [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
@@ -119,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
in += gid * axis_size;
|
||||
|
||||
constexpr int SIMD_SIZE = 32;
|
||||
|
||||
threadgroup AccT local_max[SIMD_SIZE];
|
||||
threadgroup AccT local_normalizer[SIMD_SIZE];
|
||||
|
||||
// Get the max and the normalizer in one go
|
||||
T prevmax;
|
||||
T maxval = Limits<T>::finite_min;
|
||||
T normalizer = 0;
|
||||
AccT prevmax;
|
||||
AccT maxval = Limits<AccT>::finite_min;
|
||||
AccT normalizer = 0;
|
||||
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
T vals[N_READS];
|
||||
AccT vals[N_READS];
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[offset + i];
|
||||
vals[i] = AccT(in[offset + i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] =
|
||||
(offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
|
||||
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
|
||||
: Limits<AccT>::finite_min;
|
||||
}
|
||||
}
|
||||
prevmax = maxval;
|
||||
@@ -180,49 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
|
||||
r++) {
|
||||
int offset = r * lsize * N_READS + lid * N_READS;
|
||||
if (offset + N_READS <= axis_size) {
|
||||
for (int i=0; i<N_READS; i++) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if (offset + i < axis_size) {
|
||||
out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
|
||||
out[offset + i] =
|
||||
T(softmax_exp(in[offset + i] - maxval) * normalizer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_softmax_single_row(name, itype) \
|
||||
// clang-format off
|
||||
#define instantiate_softmax(name, itype) \
|
||||
template [[host_name("softmax_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax_looped(name, itype) \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
threadgroup itype* local_max [[threadgroup(0)]], \
|
||||
threadgroup itype* local_normalizer [[threadgroup(1)]], \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_softmax(name, itype) \
|
||||
instantiate_softmax_single_row(name, itype) \
|
||||
instantiate_softmax_looped(name, itype)
|
||||
#define instantiate_softmax_precise(name, itype) \
|
||||
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
|
||||
softmax_single_row<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint _lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
|
||||
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
|
||||
softmax_looped<itype, float>( \
|
||||
const device itype* in, \
|
||||
device itype* out, \
|
||||
constant int& axis_size, \
|
||||
uint gid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
||||
instantiate_softmax(float32, float)
|
||||
instantiate_softmax(float16, half)
|
||||
instantiate_softmax(bfloat16, bfloat16_t)
|
||||
instantiate_softmax_precise(float16, half)
|
||||
instantiate_softmax_precise(bfloat16, bfloat16_t)
|
||||
// clang-format on
|
||||
|
@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
|
||||
const short base_ww_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -140,7 +140,7 @@ struct GEMMKernel {
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* C [[buffer(2)]],
|
||||
device U* D [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
@@ -167,7 +167,7 @@ struct GEMMKernel {
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
@@ -214,7 +214,7 @@ struct GEMMKernel {
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, params->ldc);
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
@@ -237,7 +237,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(C, params->ldc);
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
@@ -252,7 +252,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
@@ -267,7 +267,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
@@ -282,7 +282,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
@@ -23,8 +24,10 @@ template <typename T,
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
device T *D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -36,12 +39,25 @@ template <typename T,
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, C,
|
||||
A, B, D,
|
||||
params,
|
||||
As, Bs,
|
||||
simd_lane_id, simd_group_id, tid, lid
|
||||
@@ -57,8 +73,10 @@ template <typename T,
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *C [[buffer(2)]], \
|
||||
const constant GEMMParams* params [[buffer(3)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
|
@@ -27,7 +27,10 @@ template <typename T,
|
||||
const device T *B [[buffer(1)]],
|
||||
const device T *C [[buffer(2)]],
|
||||
device T *D [[buffer(3)]],
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@@ -50,9 +53,24 @@ template <typename T,
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
|
||||
ulong3 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
C += batch_offsets.z;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
@@ -71,9 +89,10 @@ template <typename T,
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col * params->fdc;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
@@ -83,7 +102,7 @@ template <typename T,
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
const Epilogue epilogue_op(params->alpha, params->beta);
|
||||
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
@@ -121,7 +140,7 @@ template <typename T,
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
}
|
||||
@@ -145,7 +164,7 @@ template <typename T,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
@@ -163,7 +182,7 @@ template <typename T,
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
@@ -182,7 +201,7 @@ template <typename T,
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
@@ -201,7 +220,7 @@ template <typename T,
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
@@ -219,7 +238,10 @@ template <typename T,
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]], \
|
||||
const constant GEMMParams* gemm_params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(5)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user