mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-11 14:34:37 +08:00

Compare commits (66 commits, SHA1):

bddf23f175, 039da779d1, d88d2124b5, e142aaf8a1, 0caf35f4b8, 3fc993f82d,
741eb28443, 1a87dc5ea8, 2427fa171e, 639e06e1f3, 02fedbf1da, 110d9b149d,
9cbff5ec1d, 433c0206b0, 8915901966, f48bc496c7, 913b19329c, d8cb3128f6,
5f9ba3019f, 46caf0bef0, 45f636e759, a7b404ff53, c4fd0e5ede, bab5386306,
aca7584635, d611251502, f30b659291, 90dfa43ff1, dc175f08d3, 29221fa238,
a789685c63, 240d10699c, 925014b661, 5611e1a95e, 570f2bf29e, 9948eddf11,
a3ee03da01, 28fcd2b519, 8e686764ac, 479051ce1c, bfb5bad4f0, 1e16331d9c,
be98f4ab6b, 6ee1112f30, 8e5a5a1ccd, fcda3a0e66, 9663c22fe9, f0ae00da12,
44390bd3d0, 2225374060, 105d236889, 53e6a9367c, f5a1582fe8, a54f06b16f,
4650d94d98, a5681ebc52, e849b3424a, b219d12a6b, cec8661113, 73a8c090e0,
db6796ac61, 9a8ee00246, d39ed54f8e, 16546c70d8, eaba55c9bf, 19ec023256
@@ -31,8 +31,7 @@ jobs:
          name: Install dependencies
          command: |
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install pybind11-stubgen
            pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
            pip install numpy
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -44,7 +43,8 @@ jobs:
      - run:
          name: Generate package stubs
          command: |
            python3 setup.py generate_stubs
            echo "stubs"
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
@@ -80,8 +80,7 @@ jobs:
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install pybind11-stubgen
            pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
            pip install numpy
            pip install torch
            pip install tensorflow
@@ -95,7 +94,7 @@ jobs:
          name: Generate package stubs
          command: |
            source env/bin/activate
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
@@ -144,9 +143,8 @@ jobs:
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
            pip install --upgrade setuptools
            pip install pybind11-stubgen
            pip install numpy
            pip install twine
            pip install build
@@ -161,7 +159,7 @@ jobs:
          name: Generate package stubs
          command: |
            source env/bin/activate
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
@@ -209,9 +207,8 @@ jobs:
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
            pip install --upgrade setuptools
            pip install pybind11-stubgen
            pip install numpy
            pip install auditwheel
            pip install patchelf
@@ -219,7 +216,7 @@ jobs:
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL="" \
            pip install . -v
            python setup.py generate_stubs
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL="" \
            python -m build --wheel
@@ -15,6 +15,8 @@ MLX was developed with contributions from the following individuals:

- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`.
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non-array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>
@@ -15,31 +15,33 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.7.0)
  set(MLX_VERSION 0.9.1)
endif()

# --------------------- Processor tests -------------------------

message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")

set(MLX_BUILD_ARM OFF)

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
    message(FATAL_ERROR
      "Building for x86_64 on macOS is not supported."
      " If you are on an Apple silicon system, check the build"
      " documentation for possible fixes: "
      "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
  elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
    message(WARNING
      "Building for x86_64 on macOS is not supported."
      " If you are on an Apple silicon system, "
      " make sure you are building for arm64.")
  elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
    if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
      if(NOT MLX_ENABLE_X64_MAC)
        message(FATAL_ERROR
          "Building for x86_64 on macOS is not supported."
          " If you are on an Apple silicon system, check the build"
          " documentation for possible fixes: "
          "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
      else()
        message(WARNING "Building for x86_64 arch is not officially supported.")
      endif()
      set(MLX_BUILD_METAL OFF)
    elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
      set(MLX_BUILD_ARM ON)
    endif()
@@ -64,8 +66,14 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
  message(STATUS "Metal not found. Unable to build GPU")
  set(MLX_BUILD_METAL OFF)
  set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")

  if (MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
  endif()

  # Throw an error if xcrun not found
  execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
                  OUTPUT_VARIABLE MACOS_VERSION
@@ -108,7 +116,27 @@ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
else()
  message(STATUS "Accelerate or arm neon not found, using default backend.")
  set(MLX_BUILD_ACCELERATE OFF)
  #set(BLA_VENDOR Generic)
  if(${CMAKE_HOST_APPLE})
    # The blas shipped in macOS SDK is not supported, search homebrew for
    # openblas instead.
    set(BLA_VENDOR OpenBLAS)
    set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
  endif()
  # Search and link with lapack.
  find_package(LAPACK REQUIRED)
  if (NOT LAPACK_FOUND)
    message(FATAL_ERROR "Must have LAPACK installed")
  endif()
  find_path(LAPACK_INCLUDE_DIRS lapacke.h
    /usr/include
    /usr/local/include
    /usr/local/opt/openblas/include)
  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
  message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
  target_link_libraries(mlx ${LAPACK_LIBRARIES})
  # List blas after lapack otherwise we may accidentally include an old version
  # of lapack.h from the include dirs of blas.
  find_package(BLAS REQUIRED)
  if (NOT BLAS_FOUND)
    message(FATAL_ERROR "Must have BLAS installed")
@@ -122,17 +150,6 @@ else()
  message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
  target_link_libraries(mlx ${BLAS_LIBRARIES})
  find_package(LAPACK REQUIRED)
  if (NOT LAPACK_FOUND)
    message(FATAL_ERROR "Must have LAPACK installed")
  endif()
  find_path(LAPACK_INCLUDE_DIRS lapacke.h
    /usr/include
    /usr/local/include)
  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
  message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
  target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -146,8 +163,12 @@ target_include_directories(

if (MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
  find_package(Python COMPONENTS Interpreter Development)
  find_package(pybind11 CONFIG REQUIRED)
  find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
    OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
  list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
  find_package(nanobind CONFIG REQUIRED)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
@@ -24,7 +24,7 @@
     << std::endl;

template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
double time_fn(F fn, Args&&... args) {
  // warmup
  for (int i = 0; i < 5; ++i) {
    eval(fn(std::forward<Args>(args)...));
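This hunk moves the benchmark helper to forwarding references so arguments are passed through without copies. A minimal standalone sketch of the pattern (illustrative; not MLX's exact helper):

```cpp
#include <chrono>
#include <utility>

// Args&&... are forwarding references: lvalues bind by reference and
// rvalues by rvalue reference. std::forward then preserves each
// argument's value category instead of copying it into the call.
template <typename F, typename... Args>
double time_once_ms(F&& fn, Args&&... args) {
  auto start = std::chrono::steady_clock::now();
  std::forward<F>(fn)(std::forward<Args>(args)...);
  auto stop = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(stop - start).count();
}
```

Note that forwarding inside a loop, as the warmup above does, is only safe when the callee does not move from its arguments.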
41 benchmarks/python/layer_norm_bench.py (new file)
@@ -0,0 +1,41 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def layer_norm(x, w, b, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    return (x - mu) * mx.rsqrt(v + eps) * w + b


def time_layer_norm():
    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1, 2))
    g2 = mx.grad(f2, argnums=(0, 1, 2))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, b, y)

    def layer_norm_loop(g, x, w, b):
        gx, gw, gb = x, w, b
        for _ in range(32):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    time_fn(layer_norm_loop, g1, x, w, b)
    time_fn(layer_norm_loop, g2, x, w, b)
    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)


if __name__ == "__main__":
    time_layer_norm()
39 benchmarks/python/rms_norm_bench.py (new file)
@@ -0,0 +1,39 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def rms_norm(x, w, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (x * n).astype(ot) * w


def time_rms_norm():
    f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
    f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1))
    g2 = mx.grad(f2, argnums=(0, 1))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x, w):
        gx, gw = x, w
        for _ in range(32):
            gx, gw = g(gx, gw, y)
        return gx, gw

    time_fn(rms_norm_loop, g1, x, w)
    time_fn(rms_norm_loop, g2, x, w)
    time_fn(rms_norm_loop, mx.compile(g1), x, w)
    time_fn(rms_norm_loop, mx.compile(g2), x, w)


if __name__ == "__main__":
    time_rms_norm()
@@ -6,21 +6,21 @@ from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(4096)
    rope = nn.RoPE(64)

    # vec
    x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x)
            x = rope(x, offset=100)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
BIN docs/src/_static/metal_debugger/capture.png (new file; binary file not shown; 1.2 MiB)
BIN docs/src/_static/metal_debugger/schema.png (new file; binary file not shown; 746 KiB)
@@ -29,16 +29,17 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}

intersphinx_mapping = {
    "https://docs.python.org/3": None,
    "https://numpy.org/doc/stable/": None,
    "python": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
}

templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
master_doc = "index"
main_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False

# -- Options for HTML output -------------------------------------------------

@@ -59,3 +60,22 @@ html_theme_options = {

# -- Options for HTMLHelp output ---------------------------------------------

htmlhelp_basename = "mlx_doc"


def setup(app):
    from sphinx.util import inspect

    wrapped_isfunc = inspect.isfunction

    def isfunc(obj):
        type_name = str(type(obj))
        if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
            return True
        return wrapped_isfunc(obj)

    inspect.isfunction = isfunc


# -- Options for LaTeX output ------------------------------------------------

latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
@@ -223,7 +223,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
        /* const std::vector<int>& shape = */ out_shape,
        /* Dtype dtype = */ out_dtype,
        /* std::unique_ptr<Primitive> primitive = */
        std::make_unique<Axpby>(to_stream(s), alpha, beta),
        std::make_shared<Axpby>(to_stream(s), alpha, beta),
        /* const std::vector<array>& inputs = */ broadcasted_inputs);
}
52 docs/src/dev/metal_debugger.rst (new file)
@@ -0,0 +1,52 @@
Metal Debugger
==============

Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization
workflow. The ``MLX_METAL_DEBUG`` debug option:

* Records source during Metal compilation, for later inspection while
  debugging.
* Labels Metal objects such as command queues, improving capture readability.

The ``metal::start_capture`` function initiates a capture of all MLX GPU work.

.. code-block:: C++

   int main() {
     metal::start_capture("/Users/Jane/Developer/MLX.gputrace");

     auto a = arange(10.f, 20.f, 1.f, float32);
     auto b = arange(30.f, 40.f, 1.f, float32);
     auto c = add(a, b);

     eval(c);

     metal::stop_capture();
   }

You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
gives a good overview of all operations. Check out the `Metal debugger
documentation`_ for more information.

.. image:: ../_static/metal_debugger/capture.png
   :class: dark-light

Xcode Workflow
--------------

You can skip saving to a path by running within Xcode. First, generate an Xcode
project using CMake.

.. code-block::

   mkdir build && cd build
   cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
   open mlx.xcodeproj

Select the ``metal_capture`` example scheme and run.

.. image:: ../_static/metal_debugger/schema.png
   :class: dark-light

.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger
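The same capture can also be driven from Python. A sketch, assuming the ``mx.metal.start_capture``/``mx.metal.stop_capture`` bindings (treat the exact names as an assumption, not confirmed by this diff):

```python
import mlx.core as mx

# Capture all MLX GPU work between start and stop into a .gputrace
# bundle that Xcode can open (assumed bindings; a MLX_METAL_DEBUG
# build gives the richer annotations described above).
mx.metal.start_capture("mlx_trace.gputrace")

a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a @ b)

mx.metal.stop_capture()
```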
@@ -58,10 +58,12 @@ are the CPU and GPU.
   :maxdepth: 1

   python/array
   python/data_types
   python/devices_and_streams
   python/ops
   python/random
   python/transforms
   python/fast
   python/fft
   python/linalg
   python/metal
@@ -80,3 +82,4 @@ are the CPU and GPU.
   :maxdepth: 1

   dev/extensions
   dev/metal_debugger
@@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from

   git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:

.. code-block:: shell

   pip install "pybind11[global]"
   conda install pybind11
   brew install pybind11
   pip install git+https://github.com/wjakob/nanobind.git

Then simply build and install it using pip:
Then simply build and install MLX using pip:

.. code-block:: shell

@@ -158,6 +155,8 @@ should point to the path to the built metal library.
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF


.. note::
@@ -10,27 +10,38 @@ Array

   array
   array.astype
   array.at
   array.item
   array.tolist
   array.dtype
   array.itemsize
   array.nbytes
   array.ndim
   array.shape
   array.size
   Dtype
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.cos
   array.dtype
   array.cummax
   array.cummin
   array.cumprod
   array.cumsum
   array.diag
   array.diagonal
   array.exp
   array.flatten
   array.log
   array.log10
   array.log1p
   array.log2
   array.logsumexp
   array.max
   array.mean
   array.min
   array.moveaxis
   array.prod
   array.reciprocal
   array.reshape
@@ -40,6 +51,8 @@ Array
   array.split
   array.sqrt
   array.square
   array.squeeze
   array.swapaxes
   array.sum
   array.transpose
   array.T
@@ -1,7 +1,5 @@
.. _data_types:

:orphan:

Data Types
==========

@@ -44,9 +42,27 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``bfloat16``
     - 2
     - 16-bit brain float (e8, m7)
   * - ``float16``
     - 2
     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
     - 16-bit IEEE float (e5, m10)
   * - ``float32``
     - 4
     - 32-bit float
   * - ``complex64``
     - 8
     - 64-bit complex float


Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.

.. autosummary::
   :toctree: _autosummary

   Dtype
   DtypeCategory
   issubdtype
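A short example of the subtype checks described above (a sketch; the category names follow the NumPy-style hierarchy the docs reference):

```python
import mlx.core as mx

print(mx.issubdtype(mx.float32, mx.floating))  # True
print(mx.issubdtype(mx.float32, mx.inexact))   # True
print(mx.issubdtype(mx.int32, mx.integer))     # True
print(mx.issubdtype(mx.int32, mx.floating))    # False
```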
14 docs/src/python/fast.rst (new file)
@@ -0,0 +1,14 @@
.. _fast:

Fast
====

.. currentmodule:: mlx.core.fast

.. autosummary::
   :toctree: _autosummary

   rms_norm
   layer_norm
   rope
   scaled_dot_product_attention
@@ -30,6 +30,7 @@ Module

   Module.named_modules
   Module.parameters
   Module.save_weights
   Module.set_dtype
   Module.train
   Module.trainable_parameters
   Module.unfreeze
@@ -38,6 +38,10 @@ Operations

   conv_general
   cos
   cosh
   cummax
   cummin
   cumprod
   cumsum
   dequantize
   diag
   diagonal
@@ -58,10 +62,10 @@ Operations

   identity
   inner
   isclose
   isnan
   isposinf
   isneginf
   isinf
   isnan
   isneginf
   isposinf
   less
   less_equal
   linspace
@@ -49,7 +49,7 @@ it will be added. You can load the array with:

.. code-block:: shell

   >>> mx.load("array.npy", a)
   >>> mx.load("array.npy")
   array([1], dtype=float32)

Here's an example of saving several arrays to a single file:
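For reference, a minimal sketch of that several-arrays case (using ``mx.savez``; the file and key names here are illustrative):

```python
import mlx.core as mx

a = mx.array([1.0])
b = mx.array([2.0, 3.0])

# Save both arrays into one .npz archive, then load them back as a dict.
mx.savez("arrays.npz", a=a, b=b)
loaded = mx.load("arrays.npz")
print(loaded["a"], loaded["b"])
```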
@@ -8,3 +8,4 @@ endfunction(build_example)

build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
30 examples/cpp/metal_capture.cpp (new file)
@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.

#include <cassert>
#include <iostream>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Enable the MLX_METAL_DEBUG CMake option to enhance the capture with groups,
  // labels, etc.
  assert(metal::start_capture());

  // Start at index two because the default GPU and CPU streams have indices
  // zero and one, respectively. This naming matches the label assigned to each
  // stream's command queue.
  auto s2 = new_stream(Device::gpu);
  auto s3 = new_stream(Device::gpu);

  auto a = arange(1.f, 10.f, 1.f, float32, s2);
  auto b = arange(1.f, 10.f, 1.f, float32, s3);
  auto x = add(a, a, s2);
  auto y = add(b, b, s3);

  // The multiply will happen on the default stream.
  std::cout << multiply(x, y) << std::endl;

  metal::stop_capture();
}
@@ -61,7 +61,7 @@ array axpby(
      /* const std::vector<int>& shape = */ out_shape,
      /* Dtype dtype = */ out_dtype,
      /* std::unique_ptr<Primitive> primitive = */
      std::make_unique<Axpby>(to_stream(s), alpha, beta),
      std::make_shared<Axpby>(to_stream(s), alpha, beta),
      /* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -12,16 +12,6 @@ namespace mlx::core {

namespace {

std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
  std::vector<size_t> strides(shape.size());
  size_t cum_prod = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    strides[i] = cum_prod;
    cum_prod *= shape[i];
  }
  return {cum_prod, strides};
}

/** Return true if we are currently performing a function transformation in
 * order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
@@ -36,22 +26,11 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
  init(&cval);
}

array::array(
    const std::vector<int>& shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    const std::vector<array>& inputs)
    : array_desc_(std::make_shared<ArrayDesc>(
          shape,
          dtype,
          std::move(primitive),
          inputs)) {}

array::array(
    std::vector<int> shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array>&& inputs)
    std::vector<array> inputs)
    : array_desc_(std::make_shared<ArrayDesc>(
          std::move(shape),
          dtype,
@@ -59,15 +38,16 @@ array::array(
          std::move(inputs))) {}

std::vector<array> array::make_arrays(
    const std::vector<std::vector<int>>& shapes,
    std::vector<std::vector<int>> shapes,
    const std::vector<Dtype>& dtypes,
    std::shared_ptr<Primitive> primitive,
    const std::shared_ptr<Primitive>& primitive,
    const std::vector<array>& inputs) {
  std::vector<array> outputs;
  for (int i = 0; i < shapes.size(); ++i) {
    outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
  for (size_t i = 0; i < shapes.size(); ++i) {
    outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
  }
  for (int i = 0; i < outputs.size(); ++i) {
  // For each node in |outputs|, its siblings are the other nodes.
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto siblings = outputs;
    siblings.erase(siblings.begin() + i);
    outputs[i].set_siblings(std::move(siblings), i);
@@ -92,10 +72,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)

/* Build an array from a shared buffer */
array::array(
    allocator::Buffer data,
    const std::vector<int>& shape,
    std::vector<int> shape,
    Dtype dtype,
    deleter_t deleter)
    : array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  set_data(data, deleter);
}

@@ -181,39 +161,33 @@ void array::move_shared_buffer(array other) {
  move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}

array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
    : shape(shape), dtype(dtype) {
  std::tie(size, strides) = cum_prod(shape);
}

array::ArrayDesc::ArrayDesc(
    const std::vector<int>& shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    const std::vector<array>& inputs)
    : shape(shape),
      dtype(dtype),
      primitive(std::move(primitive)),
      inputs(inputs) {
  std::tie(size, strides) = cum_prod(this->shape);
  for (auto& in : this->inputs) {
void array::ArrayDesc::init() {
  strides.resize(shape.size());
  size = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    strides[i] = size;
    size *= shape[i];
  }
  for (auto& in : inputs) {
    is_tracer |= in.is_tracer();
  }
}

array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
    : shape(std::move(shape)), dtype(dtype) {
  init();
}

array::ArrayDesc::ArrayDesc(
    std::vector<int>&& shape,
    std::vector<int> shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array>&& inputs)
    std::vector<array> inputs)
    : shape(std::move(shape)),
      dtype(dtype),
      primitive(std::move(primitive)),
      inputs(std::move(inputs)) {
  std::tie(size, strides) = cum_prod(this->shape);
  for (auto& in : this->inputs) {
    is_tracer |= in.is_tracer();
  }
  init();
}

array::ArrayIterator::ArrayIterator(const array& arr, int idx)
69 mlx/array.h
@@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once

#include <algorithm>
#include <cstdint>
#include <functional>
@@ -31,7 +32,7 @@ class array {
  template <typename It>
  array(
      It data,
      const std::vector<int>& shape,
      std::vector<int> shape,
      Dtype dtype =
          TypeToDtype<typename std::iterator_traits<It>::value_type>());

@@ -47,13 +48,13 @@ class array {
  template <typename T>
  array(
      std::initializer_list<T> data,
      const std::vector<int>& shape,
      std::vector<int> shape,
      Dtype dtype = TypeToDtype<T>());

  /* Build an array from a buffer */
  array(
      allocator::Buffer data,
      const std::vector<int>& shape,
      std::vector<int> shape,
      Dtype dtype,
      deleter_t deleter = allocator::free);

@@ -172,22 +173,16 @@ class array {
   * API may change.
   */

  array(
      const std::vector<int>& shape,
      Dtype dtype,
      std::shared_ptr<Primitive> primitive,
      const std::vector<array>& inputs);

  array(
      std::vector<int> shape,
      Dtype dtype,
      std::shared_ptr<Primitive> primitive,
      std::vector<array>&& inputs);
      std::vector<array> inputs);

  static std::vector<array> make_arrays(
      const std::vector<std::vector<int>>& shapes,
      std::vector<std::vector<int>> shapes,
      const std::vector<Dtype>& dtypes,
      std::shared_ptr<Primitive> primitive,
      const std::shared_ptr<Primitive>& primitive,
      const std::vector<array>& inputs);

  /** A unique identifier for an array. */
@@ -261,6 +256,17 @@ class array {
    array_desc_->position = position;
  }

  /** The i-th output of the array's primitive. */
  const array& output(int i) const {
    if (i == array_desc_->position) {
      return *this;
    } else if (i < array_desc_->position) {
      return siblings()[i];
    } else {
      return siblings()[i + 1];
    }
  };

  /** The outputs of the array's primitive (i.e. this array and
   * its siblings) in the order the primitive expects. */
  std::vector<array> outputs() const {
@@ -362,7 +368,7 @@ class array {
    std::vector<size_t> strides;
    size_t size;
    Dtype dtype;
    std::shared_ptr<Primitive> primitive{nullptr};
    std::shared_ptr<Primitive> primitive;

    // Indicates an array is being used in a graph transform
    // and should not be detached from the graph
@@ -370,7 +376,7 @@ class array {

    // This is a shared pointer so that *different* arrays
    // can share the underlying data buffer.
    std::shared_ptr<Data> data{nullptr};
    std::shared_ptr<Data> data;

    // Properly offset data pointer
    void* data_ptr{nullptr};
@@ -390,26 +396,24 @@ class array {
    // The arrays position in the output list
    uint32_t position{0};

    explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
    explicit ArrayDesc(std::vector<int> shape, Dtype dtype);

    explicit ArrayDesc(
        const std::vector<int>& shape,
        std::vector<int> shape,
        Dtype dtype,
        std::shared_ptr<Primitive> primitive,
        const std::vector<array>& inputs);
        std::vector<array> inputs);

    explicit ArrayDesc(
        std::vector<int>&& shape,
        Dtype dtype,
        std::shared_ptr<Primitive> primitive,
        std::vector<array>&& inputs);
   private:
    // Initialize size, strides, and other metadata
    void init();
  };

  // The ArrayDesc contains the details of the materialized array including the
  // shape, strides, the data type. It also includes
  // the primitive which knows how to compute the array's data from its inputs
  // and the list of array's inputs for the primitive.
  std::shared_ptr<ArrayDesc> array_desc_{nullptr};
  std::shared_ptr<ArrayDesc> array_desc_;
};

template <typename T>
@@ -421,9 +425,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
    It data,
    const std::vector<int>& shape,
    std::vector<int> shape,
    Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
    array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
    array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  init(data);
}

@@ -440,9 +444,9 @@ array::array(
template <typename T>
array::array(
    std::initializer_list<T> data,
    const std::vector<int>& shape,
    std::vector<int> shape,
    Dtype dtype /* = TypeToDtype<T>() */)
    : array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  if (data.size() != size()) {
    throw std::invalid_argument(
        "Data size and provided shape mismatch in array construction.");
@@ -517,4 +521,15 @@ void array::init(It src) {
  }
}

/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
    std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;

template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);

template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;

} // namespace mlx::core
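The trait utilities added at the end of this header let templates constrain themselves to array arguments. A self-contained sketch of the pattern (with a stand-in ``array`` type, not the real header):

```cpp
#include <type_traits>

struct array {};  // stand-in for mlx::core::array

template <typename T>
inline constexpr bool is_array_v =
    std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;

template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);

template <typename... T>
using enable_for_arrays_t = std::enable_if_t<is_arrays_v<T...>>;

// This overload participates in overload resolution only when every
// argument, after stripping references and cv-qualifiers, is an array.
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void touch(Arrays&&...) {}
```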
@@ -69,11 +69,13 @@ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
@@ -299,7 +301,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else if (is_floating_point(out.dtype())) {
  } else if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, [](auto x) { return std::exp(x); });
  } else {
    throw std::invalid_argument(
@@ -353,7 +355,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
    auto size = in.data_size();
    vvlog1pf(
        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else if (is_floating_point(out.dtype())) {
  } else if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, [](auto x) { return std::log1p(x); });
  } else {
    throw std::invalid_argument(
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
  }
};

template <typename T, typename VT, typename Ops, int N>
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
  Ops ops;

@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
  VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
  size_t s = M;
  while (s >= N) {
    vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
    VT vals;
    if constexpr (std::is_same<T, AccT>::value) {
      vals = ops.load(current_in_ptr);
    } else {
      for (int i = 0; i < N; ++i) {
        vals[i] = static_cast<AccT>(current_in_ptr[i]);
      }
    }
    vmaximum = ops.max(vals, vmaximum);
    current_in_ptr += N;
    s -= N;
  }
  T maximum = ops.reduce_max(vmaximum);
  AccT maximum = ops.reduce_max(vmaximum);
  while (s-- > 0) {
    maximum = std::max(maximum, *current_in_ptr);
    maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
    current_in_ptr++;
  }

@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
  current_in_ptr = in_ptr;
  s = M;
  while (s >= N) {
    VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
    ops.store(current_out_ptr, vexp);
    *(VT*)current_out_ptr = vexp;
    VT vexp;
    if constexpr (std::is_same<T, AccT>::value) {
      vexp = ops.load(current_in_ptr);
    } else {
      for (int i = 0; i < N; ++i) {
        vexp[i] = static_cast<AccT>(current_in_ptr[i]);
      }
    }
    vexp = ops.exp(ops.sub(vexp, maximum));
    if constexpr (std::is_same<T, AccT>::value) {
      ops.store(current_out_ptr, vexp);
    }
    vnormalizer = ops.add(vnormalizer, vexp);
    current_in_ptr += N;
    current_out_ptr += N;
    s -= N;
  }
  T normalizer = ops.reduce_add(vnormalizer);
  AccT normalizer = ops.reduce_add(vnormalizer);
  while (s-- > 0) {
    T _exp = std::exp(*current_in_ptr - maximum);
    *current_out_ptr = _exp;
    AccT _exp = std::exp(*current_in_ptr - maximum);
    if (std::is_same<T, AccT>::value) {
      *current_out_ptr = _exp;
    }
    normalizer += _exp;
    current_in_ptr++;
    current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {

  // Normalize
  current_out_ptr = out_ptr;
  current_in_ptr = in_ptr;
  s = M;
  while (s >= N) {
    ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
    if constexpr (std::is_same<T, AccT>::value) {
      ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
    } else {
      VT vexp;
      for (int i = 0; i < N; ++i) {
        vexp[i] = static_cast<AccT>(current_in_ptr[i]);
      }
      vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
      for (int i = 0; i < N; ++i) {
        current_out_ptr[i] = vexp[i];
      }
      current_in_ptr += N;
    }
    current_out_ptr += N;
    s -= N;
  }
  while (s-- > 0) {
    *current_out_ptr *= normalizer;
    if constexpr (std::is_same<T, AccT>::value) {
      *current_out_ptr *= normalizer;
    } else {
      AccT _exp = std::exp(*current_in_ptr - maximum);
      *current_out_ptr = static_cast<T>(_exp * normalizer);
      current_in_ptr++;
    }
    current_out_ptr++;
  }
}
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
          in, out);
      softmax<
          float,
          float,
          simd_float16,
          AccelerateSimdOps<float, simd_float16>,
          16>(in, out);
      break;
    case float16:
      softmax<
          float16_t,
          float16x8_t,
          NeonFp16SimdOps<float16_t, float16x8_t>,
          8>(in, out);
      if (precise_) {
        softmax<
            float16_t,
            float,
            simd_float16,
            AccelerateSimdOps<float, simd_float16>,
            16>(in, out);
      } else {
        softmax<
            float16_t,
            float16_t,
            float16x8_t,
            NeonFp16SimdOps<float16_t, float16x8_t>,
            8>(in, out);
      }
      break;
    case bfloat16:
      eval(inputs, out);
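The gist of the new ``AccT`` parameter: in the ``precise_`` path, float16 inputs are read element-wise but the running maximum and normalizer accumulate in float32, rounding back to float16 only on the final store. A scalar sketch of the idea (illustrative; the code above does the same with SIMD vectors):

```cpp
#include <algorithm>
#include <cmath>

// SrcT is the storage type (e.g. a half-precision float), AccT the
// accumulation type (e.g. float). The max and all sums run in AccT.
template <typename SrcT, typename AccT>
void softmax_row(const SrcT* in, SrcT* out, int n) {
  AccT maximum = static_cast<AccT>(in[0]);
  for (int i = 1; i < n; ++i) {
    maximum = std::max(maximum, static_cast<AccT>(in[i]));
  }
  AccT normalizer = 0;
  for (int i = 0; i < n; ++i) {
    normalizer += std::exp(static_cast<AccT>(in[i]) - maximum);
  }
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<SrcT>(
        std::exp(static_cast<AccT>(in[i]) - maximum) / normalizer);
  }
}
```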
@@ -44,7 +44,6 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -54,6 +53,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
|
@@ -179,18 +179,16 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
}
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||
|
@@ -126,4 +126,102 @@ std::string build_lib_name(
  return os.str();
}

bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const std::vector<int>& shape) {
  bool contiguous = true;
  bool all_contig = true;
  bool all_row_contig = true;
  bool all_col_contig = true;
  int non_scalar_inputs = 0;
  for (const auto& x : inputs) {
    if (is_scalar(x)) {
      continue;
    }
    non_scalar_inputs++;
    bool shape_eq = x.shape() == shape;
    all_contig &= (x.flags().contiguous && shape_eq);
    all_row_contig &= (x.flags().row_contiguous && shape_eq);
    all_col_contig &= (x.flags().col_contiguous && shape_eq);
  }
  if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
    contiguous = false;
  } else if (non_scalar_inputs == 1 && !all_contig) {
    contiguous = false;
  } else if (non_scalar_inputs == 0 && !shape.empty()) {
    contiguous = false;
  }
  return contiguous;
}

void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers /* = false */) {
  if (contiguous) {
    int o = 0;
    std::vector<size_t> strides;
    size_t data_size;
    array::Flags flags;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Correct size
      // - Not a scalar
      // - Donatable
      // - Not a constant
      if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        if (move_buffers) {
          outputs[o++].move_shared_buffer(in);
        } else {
          outputs[o++].copy_shared_buffer(in);
        }
      }
      // Get representative input flags to properly set non-donated outputs
      if (strides.empty() && in.size() == outputs[0].size()) {
        strides = in.strides();
        flags = in.flags();
        data_size = in.data_size();
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(
          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
          data_size,
          strides,
          flags);
    }
  } else {
    int o = 0;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Row contiguous
      // - Donatable
      // - Correct size
      // - Not a constant
      if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        if (move_buffers) {
          outputs[o].move_shared_buffer(
              in, outputs[o].strides(), in.flags(), in.data_size());
        } else {
          outputs[o].copy_shared_buffer(
              in, outputs[o].strides(), in.flags(), in.data_size());
        }
        o++;
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
    }
  }
}

} // namespace mlx::core
@@ -53,4 +53,18 @@ inline bool is_scalar(const array& x) {
  return x.ndim() == 0;
}

// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const std::vector<int>& shape);

// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers = false);

} // namespace mlx::core
@@ -52,8 +52,25 @@ void* compile(
    return nullptr;
  }

  std::string kernel_file_name;

  // Deal with long kernel names. Maximum length for files on macOS is 255
  // characters. Clip file name with a little extra room and append a 16
  // character hash.
  constexpr int max_file_name_length = 245;
  if (kernel_name.size() > max_file_name_length) {
    std::ostringstream file_name;
    file_name
        << std::string_view(kernel_name).substr(0, max_file_name_length - 16);
    auto file_id = std::hash<std::string>{}(kernel_name);
    file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
    kernel_file_name = file_name.str();
  } else {
    kernel_file_name = kernel_name;
  }

  std::ostringstream shared_lib_name;
  shared_lib_name << "lib" << kernel_name << ".so";
  shared_lib_name << "lib" << kernel_file_name << ".so";
  auto shared_lib_path = get_temp_file(shared_lib_name.str());
  bool lib_exists = false;
  {
@@ -64,7 +81,7 @@ void* compile(
  if (!lib_exists) {
    // Open source file and write source code to it
    std::ostringstream source_file_name;
    source_file_name << kernel_name << ".cpp";
    source_file_name << kernel_file_name << ".cpp";
    auto source_file_path = get_temp_file(source_file_name.str());

    std::ofstream source_file(source_file_path);
@@ -248,28 +265,7 @@ void Compiled::eval_cpu(

  // Figure out which kernel we are using
  auto& shape = outputs[0].shape();
  bool contiguous = true;
  {
    bool all_contig = true;
    bool all_row_contig = true;
    bool all_col_contig = true;
    int non_scalar_inputs = 0;
    for (auto& x : inputs) {
      if (is_scalar(x)) {
        continue;
      }
      non_scalar_inputs++;
      bool shape_eq = x.shape() == shape;
      all_contig &= (x.flags().contiguous && shape_eq);
      all_row_contig &= (x.flags().row_contiguous && shape_eq);
      all_col_contig &= (x.flags().col_contiguous && shape_eq);
    }
    if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
      contiguous = false;
    } else if (non_scalar_inputs == 1 && !all_contig) {
      contiguous = false;
    }
  }
  bool contiguous = compiled_check_contiguity(inputs, shape);

  // Handle all broadcasting and collect function input arguments
  std::vector<void*> args;
@@ -342,58 +338,8 @@ void Compiled::eval_cpu(
    fn_ptr = compile(kernel_name, kernel.str());
  }

  // Allocate space for the outputs possibly with input donation
  if (contiguous) {
    int o = 0;
    std::vector<size_t> strides;
    size_t data_size;
    array::Flags flags;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Contiguous
      // - Donatable
      // - Correct size
      // - Not a constant
      if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        outputs[o++].copy_shared_buffer(in);
      }
      // Get representative input flags to properly set non-donated outputs
      if (strides.empty() && in.size() == outputs[0].size()) {
        strides = in.strides();
        flags = in.flags();
        data_size = in.data_size();
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(
          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
          data_size,
          strides,
          flags);
    }
  } else {
    int o = 0;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Row contiguous
      // - Donatable
      // - Correct size
      // - Not a constant
      if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        outputs[o].copy_shared_buffer(
            in, outputs[o].strides(), in.flags(), in.data_size());
        o++;
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
    }
  }
  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous, false);

  for (auto& x : outputs) {
    args.push_back(x.data<void>());
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
@@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) {
|
||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim1(const array& src, array& dst) {
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim1(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[0];
|
||||
src_idx += i_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_general_dim2(const array& src, array& dst) {
|
||||
inline void copy_general_dim1(const array& src, array& dst) {
|
||||
return copy_general_dim1<SrcT, DstT, size_t>(
|
||||
src, dst, src.shape(), src.strides(), 0);
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename stride_t>
|
||||
void copy_general_dim2(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
const SrcT* src_ptr = src.data<SrcT>();
|
||||
DstT* dst_ptr = dst.data<DstT>();
|
||||
size_t src_idx = 0;
|
||||
size_t dst_idx = 0;
|
||||
for (size_t i = 0; i < src.shape()[0]; ++i) {
|
||||
for (size_t j = 0; j < src.shape()[1]; ++j) {
|
||||
stride_t src_idx = i_offset;
|
||||
stride_t dst_idx = 0;
|
||||
for (int i = 0; i < data_shape[0]; ++i) {
|
||||
for (int j = 0; j < data_shape[1]; ++j) {
|
||||
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
|
||||
src_idx += src.strides()[1];
|
||||
src_idx += i_strides[1];
|
||||
}
|
||||
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
|
||||
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
|
||||
}
|
||||
}

template <typename SrcT, typename DstT>
void copy_general_dim3(const array& src, array& dst) {
inline void copy_general_dim2(const array& src, array& dst) {
  return copy_general_dim2<SrcT, DstT, size_t>(
      src, dst, src.shape(), src.strides(), 0);
}

template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    int64_t i_offset) {
  const SrcT* src_ptr = src.data<SrcT>();
  DstT* dst_ptr = dst.data<DstT>();
  size_t src_idx = 0;
  size_t dst_idx = 0;
  for (size_t i = 0; i < src.shape()[0]; ++i) {
    for (size_t j = 0; j < src.shape()[1]; ++j) {
      for (size_t k = 0; k < src.shape()[2]; ++k) {
  stride_t src_idx = i_offset;
  stride_t dst_idx = 0;
  for (int i = 0; i < data_shape[0]; ++i) {
    for (int j = 0; j < data_shape[1]; ++j) {
      for (int k = 0; k < data_shape[2]; ++k) {
        dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
        src_idx += src.strides()[2];
        src_idx += i_strides[2];
      }
      src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
      src_idx += i_strides[1] - i_strides[2] * data_shape[2];
    }
    src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
    src_idx += i_strides[0] - i_strides[1] * data_shape[1];
  }
}

template <typename SrcT, typename DstT>
void copy_general_dim4(const array& src, array& dst) {
inline void copy_general_dim3(const array& src, array& dst) {
  return copy_general_dim3<SrcT, DstT, size_t>(
      src, dst, src.shape(), src.strides(), 0);
}

template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    int64_t i_offset) {
  const SrcT* src_ptr = src.data<SrcT>();
  DstT* dst_ptr = dst.data<DstT>();
  size_t src_idx = 0;
  size_t dst_idx = 0;
  for (size_t i = 0; i < src.shape()[0]; ++i) {
    for (size_t j = 0; j < src.shape()[1]; ++j) {
      for (size_t k = 0; k < src.shape()[2]; ++k) {
        for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
  stride_t src_idx = i_offset;
  stride_t dst_idx = 0;
  for (int i = 0; i < data_shape[0]; ++i) {
    for (int j = 0; j < data_shape[1]; ++j) {
      for (int k = 0; k < data_shape[2]; ++k) {
        for (int ii = 0; ii < data_shape[3]; ++ii) {
          dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
          src_idx += src.strides()[3];
          src_idx += i_strides[3];
        }
        src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
        src_idx += i_strides[2] - i_strides[3] * data_shape[3];
      }
      src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
      src_idx += i_strides[1] - i_strides[2] * data_shape[2];
    }
    src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
    src_idx += i_strides[0] - i_strides[1] * data_shape[1];
  }
}

template <typename SrcT, typename DstT>
void copy_general(const array& src, array& dst) {
inline void copy_general_dim4(const array& src, array& dst) {
  return copy_general_dim4<SrcT, DstT, size_t>(
      src, dst, src.shape(), src.strides(), 0);
}

template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    int64_t i_offset) {
  switch (src.ndim()) {
    case 1:
      copy_general_dim1<SrcT, DstT>(src, dst);
      copy_general_dim1<SrcT, DstT, stride_t>(
          src, dst, data_shape, i_strides, i_offset);
      return;
    case 2:
      copy_general_dim2<SrcT, DstT>(src, dst);
      copy_general_dim2<SrcT, DstT, stride_t>(
          src, dst, data_shape, i_strides, i_offset);
      return;
    case 3:
      copy_general_dim3<SrcT, DstT>(src, dst);
      copy_general_dim3<SrcT, DstT, stride_t>(
          src, dst, data_shape, i_strides, i_offset);
      return;
    case 4:
      copy_general_dim4<SrcT, DstT>(src, dst);
      copy_general_dim4<SrcT, DstT, stride_t>(
          src, dst, data_shape, i_strides, i_offset);
      return;
  }

  auto src_ptr = src.data<SrcT>();
  auto src_ptr = src.data<SrcT>() + i_offset;
  auto dst_ptr = dst.data<DstT>();
  for (size_t i = 0; i < dst.size(); ++i) {
    size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
    stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
    dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
  }
}

template <typename SrcT, typename DstT, int D>
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
  return copy_general<SrcT, DstT, size_t>(
      src, dst, src.shape(), src.strides(), 0);
}

template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    const std::vector<stride_t>& o_strides,
    int64_t i_offset,
    int64_t o_offset) {
  return copy_general<SrcT, DstT, stride_t>(
      src, dst, data_shape, i_strides, i_offset);
}

template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
    const array& src,
    array& dst,
    size_t offset_src,
    size_t offset_dst) {
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    const std::vector<stride_t>& o_strides,
    stride_t i_offset,
    stride_t o_offset) {
  if constexpr (D > 1) {
    int axis = src.ndim() - D;
    auto stride_src = src.strides()[axis];
    auto stride_dst = dst.strides()[axis];
    auto N = src.shape(axis);
    auto stride_src = i_strides[axis];
    auto stride_dst = o_strides[axis];
    auto N = data_shape[axis];
    for (int i = 0; i < N; i++) {
      copy_general_general_dims<SrcT, DstT, D - 1>(
          src, dst, offset_src, offset_dst);
      offset_src += stride_src;
      offset_dst += stride_dst;
      copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      i_offset += stride_src;
      o_offset += stride_dst;
    }
  } else {
    int axis = src.ndim() - 1;
    auto stride_src = src.strides()[axis];
    auto stride_dst = dst.strides()[axis];
    auto N = src.shape(axis);
    const SrcT* src_ptr = src.data<SrcT>() + offset_src;
    DstT* dst_ptr = dst.data<DstT>() + offset_dst;
    auto stride_src = i_strides[axis];
    auto stride_dst = o_strides[axis];
    auto N = data_shape[axis];
    const SrcT* src_ptr = src.data<SrcT>() + i_offset;
    DstT* dst_ptr = dst.data<DstT>() + o_offset;
    for (int i = 0; i < N; i++) {
      *dst_ptr = static_cast<DstT>(*src_ptr);
      src_ptr += stride_src;
@@ -148,37 +223,56 @@ inline void copy_general_general_dims(
  }
}

template <typename SrcT, typename DstT>
void copy_general_general(const array& src, array& dst) {
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    const std::vector<stride_t>& o_strides,
    stride_t i_offset,
    stride_t o_offset) {
  switch (src.ndim()) {
    case 1:
      copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
      copy_general_general_dims<SrcT, DstT, stride_t, 1>(
          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      return;
    case 2:
      copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
      copy_general_general_dims<SrcT, DstT, stride_t, 2>(
          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      return;
    case 3:
      copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
      copy_general_general_dims<SrcT, DstT, stride_t, 3>(
          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      return;
    case 4:
      copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
      copy_general_general_dims<SrcT, DstT, stride_t, 4>(
          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      return;
    case 5:
      copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
      copy_general_general_dims<SrcT, DstT, stride_t, 5>(
          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      return;
  }

  int size = std::accumulate(
      src.shape().end() - 5, src.shape().end(), 1, std::multiplies<int>());
      data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
  for (int i = 0; i < src.size(); i += size) {
    size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
    size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
    copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
    stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
    stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
    copy_general_general_dims<SrcT, DstT, stride_t, 5>(
        src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
  }
}
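
// For ndim > 5 the fallback above copies the array in blocks of the five
// innermost dimensions: `size` is the element count of those axes, and each
// iteration re-derives the strided offsets of the next block from the linear
// index via elem_to_loc. A minimal sketch of the same idea, with `total`,
// `block_size` and copy_block standing in as hypothetical names:
//
//   for (int i = 0; i < total; i += block_size) {
//     auto src_off = i_offset + elem_to_loc(i, data_shape, i_strides);
//     auto dst_off = o_offset + elem_to_loc(i, dst.shape(), o_strides);
//     copy_block(src, dst, src_off, dst_off);  // hypothetical helper
//   }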

template <typename SrcT, typename DstT>
void copy(const array& src, array& dst, CopyType ctype) {
inline void copy_general_general(const array& src, array& dst) {
  return copy_general_general<SrcT, DstT, size_t>(
      src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}

template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
  switch (ctype) {
    case CopyType::Scalar:
      copy_single<SrcT, DstT>(src, dst);
@@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) {
      copy_vector<SrcT, DstT>(src, dst);
      return;
    case CopyType::General:
      copy_general<SrcT, DstT>(src, dst);
      copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
      return;
    case CopyType::GeneralGeneral:
      copy_general_general<SrcT, DstT>(src, dst);
      copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
  }
}

template <typename SrcT>
void copy(const array& src, array& dst, CopyType ctype) {
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
  switch (dst.dtype()) {
    case bool_:
      copy<SrcT, bool>(src, dst, ctype);
      copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint8:
      copy<SrcT, uint8_t>(src, dst, ctype);
      copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint16:
      copy<SrcT, uint16_t>(src, dst, ctype);
      copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint32:
      copy<SrcT, uint32_t>(src, dst, ctype);
      copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint64:
      copy<SrcT, uint64_t>(src, dst, ctype);
      copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int8:
      copy<SrcT, int8_t>(src, dst, ctype);
      copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int16:
      copy<SrcT, int16_t>(src, dst, ctype);
      copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int32:
      copy<SrcT, int32_t>(src, dst, ctype);
      copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int64:
      copy<SrcT, int64_t>(src, dst, ctype);
      copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float16:
      copy<SrcT, float16_t>(src, dst, ctype);
      copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float32:
      copy<SrcT, float>(src, dst, ctype);
      copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<SrcT, bfloat16_t>(src, dst, ctype);
      copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case complex64:
      copy<SrcT, complex64_t>(src, dst, ctype);
      copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
  }
}

template <typename... Args>
inline void copy_inplace_dispatch(
    const array& src,
    array& dst,
    CopyType ctype,
    Args&&... args) {
  switch (src.dtype()) {
    case bool_:
      copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint8:
      copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint16:
      copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint32:
      copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case uint64:
      copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int8:
      copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int16:
      copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int32:
      copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case int64:
      copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float16:
      copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float32:
      copy<float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case complex64:
      copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
  }
}

@@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) {
} // namespace

void copy_inplace(const array& src, array& dst, CopyType ctype) {
  switch (src.dtype()) {
    case bool_:
      copy<bool>(src, dst, ctype);
      break;
    case uint8:
      copy<uint8_t>(src, dst, ctype);
      break;
    case uint16:
      copy<uint16_t>(src, dst, ctype);
      break;
    case uint32:
      copy<uint32_t>(src, dst, ctype);
      break;
    case uint64:
      copy<uint64_t>(src, dst, ctype);
      break;
    case int8:
      copy<int8_t>(src, dst, ctype);
      break;
    case int16:
      copy<int16_t>(src, dst, ctype);
      break;
    case int32:
      copy<int32_t>(src, dst, ctype);
      break;
    case int64:
      copy<int64_t>(src, dst, ctype);
      break;
    case float16:
      copy<float16_t>(src, dst, ctype);
      break;
    case float32:
      copy<float>(src, dst, ctype);
      break;
    case bfloat16:
      copy<bfloat16_t>(src, dst, ctype);
      break;
    case complex64:
      copy<complex64_t>(src, dst, ctype);
      break;
  }
  return copy_inplace_dispatch(src, dst, ctype);
}

void copy(const array& src, array& dst, CopyType ctype) {
@@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) {
  copy_inplace(src, dst, ctype);
}

template <typename stride_t>
void copy_inplace(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    const std::vector<stride_t>& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype) {
  switch (ctype) {
    case CopyType::General:
    case CopyType::GeneralGeneral:
      return copy_inplace_dispatch(
          src,
          dst,
          ctype,
          data_shape,
          i_strides,
          o_strides,
          i_offset,
          o_offset);

    case CopyType::Scalar:
    case CopyType::Vector:
      return copy_inplace_dispatch(src, dst, ctype);
  }
}

template <>
void copy_inplace<int64_t>(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<int64_t>& i_strides,
    const std::vector<int64_t>& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype) {
  switch (ctype) {
    case CopyType::General:
    case CopyType::GeneralGeneral:
      return copy_inplace_dispatch(
          src,
          dst,
          ctype,
          data_shape,
          i_strides,
          o_strides,
          i_offset,
          o_offset);

    case CopyType::Scalar:
    case CopyType::Vector:
      return copy_inplace_dispatch(src, dst, ctype);
  }
}

} // namespace mlx::core

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -26,4 +26,15 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);

template <typename stride_t>
void copy_inplace(
    const array& src,
    array& dst,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    const std::vector<stride_t>& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype);

} // namespace mlx::core
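
// A hedged usage sketch of the strided overload declared above (the variable
// names are illustrative, not from this header). Copying a reversed view of
// src into a freshly allocated, row-contiguous dst could look like:
//
//   std::vector<int64_t> istrides{src.strides().begin(), src.strides().end()};
//   int64_t ioffset = (src.shape(-1) - 1) * src.strides().back();
//   istrides.back() *= -1;  // walk the last axis backwards
//   std::vector<int64_t> ostrides{dst.strides().begin(), dst.strides().end()};
//   copy_inplace<int64_t>(
//       src, dst, dst.shape(), istrides, ostrides, ioffset, 0,
//       CopyType::General);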

@@ -94,6 +94,7 @@ DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
@@ -105,6 +106,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)

namespace {

mlx/backend/common/inverse.cpp (new file, 104 lines)
@@ -0,0 +1,104 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

void inverse_impl(const array& a, array& inv) {
  // Lapack uses the column-major convention. We take advantage of the following
  // identity to avoid transposing (see
  // https://math.stackexchange.com/a/340234):
  // (A⁻¹)ᵀ = (Aᵀ)⁻¹

  // The inverse is computed in place, so just copy the input to the output.
  copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  const int N = a.shape(-1);
  const size_t num_matrices = a.size() / (N * N);

  int info;
  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};

  for (int i = 0; i < num_matrices; i++) {
    // Compute LU factorization.
    sgetrf_(
        /* m = */ &N,
        /* n = */ &N,
        /* a = */ inv.data<float>() + N * N * i,
        /* lda = */ &N,
        /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "inverse_impl: LU factorization failed with error code " << info;
      throw std::runtime_error(ss.str());
    }

    static const int lwork_query = -1;
    float workspace_size = 0;

    // Compute workspace size.
    sgetri_(
        /* m = */ &N,
        /* a = */ nullptr,
        /* lda = */ &N,
        /* ipiv = */ nullptr,
        /* work = */ &workspace_size,
        /* lwork = */ &lwork_query,
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "inverse_impl: LU workspace calculation failed with error code "
         << info;
      throw std::runtime_error(ss.str());
    }

    const int lwork = workspace_size;
    auto scratch =
        array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};

    // Compute inverse.
    sgetri_(
        /* m = */ &N,
        /* a = */ inv.data<float>() + N * N * i,
        /* lda = */ &N,
        /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
        /* lwork = */ &lwork,
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "inverse_impl: inversion failed with error code " << info;
      throw std::runtime_error(ss.str());
    }
  }
}

void Inverse::eval(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Inverse::eval] only supports float32.");
  }
  inverse_impl(inputs[0], output);
}

std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0] >= 0 ? 0 : -1;
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
  return {{linalg::inv(a, stream())}, {ax}};
}

} // namespace mlx::core
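
// The two sgetri_ calls above follow the usual LAPACK workspace-query
// convention: passing lwork = -1 makes the routine write its optimal
// workspace size into `work` without doing any real work, and the second
// call then supplies a buffer of that size. A minimal sketch of the pattern,
// assuming some LAPACK-style routine xfoo_ (a placeholder name):
//
//   int lwork = -1, info;
//   float query = 0;
//   xfoo_(..., /* work = */ &query, /* lwork = */ &lwork, &info);
//   lwork = static_cast<int>(query);
//   std::vector<float> work(lwork);
//   xfoo_(..., /* work = */ work.data(), /* lwork = */ &lwork, &info);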

@@ -22,7 +22,7 @@ namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (is_unsigned(in.dtype())) {
  if (issubdtype(in.dtype(), unsignedinteger)) {
    // No-op for unsigned types
    out.copy_shared_buffer(in);
  } else {
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::ArcCos());
  } else {
    throw std::invalid_argument(
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::ArcCosh());
  } else {
    throw std::invalid_argument(
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::ArcSin());
  } else {
    throw std::invalid_argument(
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::ArcSinh());
  } else {
    throw std::invalid_argument(
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::ArcTan());
  } else {
    throw std::invalid_argument(
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::ArcTanh());
  } else {
    throw std::invalid_argument(
@@ -171,7 +171,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
void Ceil::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (not is_integral(in.dtype())) {
  if (issubdtype(in.dtype(), inexact)) {
    unary_fp(in, out, detail::Ceil());
  } else {
    // No-op integer types
@@ -211,7 +211,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
void Cos::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Cos());
  } else {
    throw std::invalid_argument(
@@ -223,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
void Cosh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Cosh());
  } else {
    throw std::invalid_argument(
@@ -350,7 +350,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Exp());
  } else {
    throw std::invalid_argument(
@@ -362,7 +362,7 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
void Floor::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (not is_integral(in.dtype())) {
  if (issubdtype(in.dtype(), inexact)) {
    unary_fp(in, out, detail::Floor());
  } else {
    // No-op integer types
@@ -388,7 +388,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
void Log::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    switch (base_) {
      case Base::e:
        unary_fp(in, out, detail::Log());
@@ -410,7 +410,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
void Log1p::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Log1p());
  } else {
    throw std::invalid_argument(
@@ -597,7 +597,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
void Round::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (not is_integral(in.dtype())) {
  if (issubdtype(in.dtype(), inexact)) {
    unary_fp(in, out, detail::Round());
  } else {
    // No-op integer types
@@ -608,7 +608,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Sigmoid());
  } else {
    throw std::invalid_argument(
@@ -630,7 +630,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
void Sin::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Sin());
  } else {
    throw std::invalid_argument(
@@ -642,7 +642,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
void Sinh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Sinh());
  } else {
    throw std::invalid_argument(
@@ -651,36 +651,33 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
  }
}

void Slice::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  auto& in = inputs[0];
  auto strides = in.strides();
  auto flags = in.flags();
  size_t data_offset = 0;
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    strides[i] *= strides_[i];
    inp_strides[i] = in.strides()[i] * strides_[i];

    copy_needed |= strides_[i] < 0;
  }

  return std::make_tuple(copy_needed, data_offset, inp_strides);
}
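
// Note that copy_needed is set only when a slice step is negative: a negative
// stride cannot be expressed through the shared-buffer flags used below, so
// e.g. x[..., ::-1] (strides_[-1] == -1) forces a materializing copy, while
// x[::2] keeps positive strides and remains a zero-copy view.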

void Slice::shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    size_t data_offset,
    array& out) {
  // Compute row/col contiguity
  size_t data_size = 1;
  size_t f_stride = 1;
  size_t b_stride = 1;
  flags.row_contiguous = true;
  flags.col_contiguous = true;
  for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
    flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
    flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
    f_stride *= out.shape(i);
    b_stride *= out.shape(ri);
    if (strides[i] > 0) {
      data_size *= out.shape(i);
    }
  }
  auto [data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(out.shape(), out_strides);

  auto flags = in.flags();
  flags.row_contiguous = is_row_contiguous;
  flags.col_contiguous = is_col_contiguous;

  if (data_size == 1) {
    // Broadcasted scalar array is contiguous.
@@ -694,7 +691,87 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
  }

  out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}

void Slice::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];

  // Calculate out strides, initial offset and if copy needs to be made
  auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);

  // Do copy if needed
  if (copy_needed) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
    copy_inplace<int64_t>(
        /* const array& src = */ in,
        /* array& dst = */ out,
        /* const std::vector<int>& data_shape = */ out.shape(),
        /* const std::vector<stride_t>& i_strides = */ inp_strides,
        /* const std::vector<stride_t>& o_strides = */ ostrides,
        /* int64_t i_offset = */ data_offset,
        /* int64_t o_offset = */ 0,
        /* CopyType ctype = */ CopyType::General);
  } else {
    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
    shared_buffer_slice(in, ostrides, data_offset, out);
  }
}

std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];
  }

  return std::make_tuple(data_offset, inp_strides);
}

void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  // Check if materialization is needed
  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);

  // Calculate out strides, initial offset and if copy needs to be made
  auto [data_offset, out_strides] = prepare_slice(out);

  // Do copy
  std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
  copy_inplace<int64_t>(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const std::vector<int>& data_shape = */ upd.shape(),
      /* const std::vector<stride_t>& i_strides = */ upd_strides,
      /* const std::vector<stride_t>& o_strides = */ out_strides,
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ data_offset,
      /* CopyType ctype = */ CopyType::GeneralGeneral);
}
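
// In effect SliceUpdate::eval first materializes `in` into `out` and then
// scatters `upd` through the sliced strides, i.e. roughly the NumPy-style
//   out = in.copy(); out[slice] = upd
// where the GeneralGeneral copy walks upd row-major on the input side and
// writes through out_strides starting at the slice's data_offset.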

void Split::eval(
@@ -773,7 +850,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Tan::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Tan());
  } else {
    throw std::invalid_argument(
@@ -785,7 +862,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
void Tanh::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (is_floating_point(out.dtype())) {
  if (issubdtype(out.dtype(), inexact)) {
    unary_fp(in, out, detail::Tanh());
  } else {
    throw std::invalid_argument(

@@ -6,8 +6,6 @@

namespace mlx::core {

namespace {

enum ReductionOpType {
  // Self-explanatory. Read everything and produce 1 output.
  ContiguousAllReduce,
@@ -38,6 +36,21 @@ enum ReductionOpType {
  GeneralReduce
};

struct ReductionPlan {
  ReductionOpType type;
  std::vector<int> shape;
  std::vector<size_t> strides;

  ReductionPlan(
      ReductionOpType type_,
      std::vector<int> shape_,
      std::vector<size_t> strides_)
      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};

namespace {

// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
@@ -110,19 +123,6 @@ struct DefaultContiguousReduce {
  }
};

struct ReductionPlan {
  ReductionOpType type;
  std::vector<int> shape;
  std::vector<size_t> strides;

  ReductionPlan(
      ReductionOpType type_,
      std::vector<int> shape_,
      std::vector<size_t> strides_)
      : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&

@@ -1,13 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void RoPE::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("NYI");
}

} // namespace mlx::core::fast
@@ -222,7 +222,7 @@ void scan_dispatch(
    }
    case Scan::Min: {
      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
      auto init = (is_floating_point(input.dtype()))
      auto init = (issubdtype(input.dtype(), floating))
          ? static_cast<U>(std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::max();
      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
@@ -232,7 +232,7 @@ void scan_dispatch(
    }
    case Scan::Max: {
      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
      auto init = (is_floating_point(input.dtype()))
      auto init = (issubdtype(input.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::min();
      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {

namespace {

template <typename T>
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
  for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
    // Find the maximum
    current_in_ptr = in_ptr;
    T maximum = *current_in_ptr;
    AccT maximum = *current_in_ptr;
    for (int j = 0; j < N; j++, current_in_ptr++) {
      maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
      maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
                                            : maximum;
    }

    // Compute the normalizer and the exponentials
    T normalizer = 0;
    AccT normalizer = 0;
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
      T expv = std::exp(*current_in_ptr - maximum);
      AccT expv = std::exp(*current_in_ptr - maximum);
      normalizer += expv;
      *current_out_ptr = expv;
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr = expv;
      }
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_in_ptr = in_ptr;
    current_out_ptr = out_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++) {
      *current_out_ptr *= normalizer;
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr *= normalizer;
      } else {
        auto v = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(v * normalizer);
        current_in_ptr++;
      }
    }
  }
}
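
// The AccT parameter lets half-precision inputs accumulate the maximum and
// the normalizer in float. A rough sketch of why, assuming N half values of
// about 1e-4 each: once the running sum grows to ~0.1, each new 1e-4 addend
// sits near float16's ~1e-3 relative precision and starts to be rounded
// away, while a float accumulator keeps ~7 significant digits throughout:
//
//   float acc = 0;  // AccT
//   for (int j = 0; j < N; ++j)
//     acc += static_cast<float>(x[j]);  // upcast each element
//   // ...and only downcast once per output element at the very end.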
@@ -67,11 +77,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
    }
  };
  array in = check_input(std::move(inputs[0]));
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());
  if (in.is_donatable()) {
    out.copy_shared_buffer(in);
  } else {
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  }

  switch (in.dtype()) {
    case bool_:
@@ -87,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float>(in, out);
      softmax<float, float>(in, out);
      break;
    case float16:
      softmax<float16_t>(in, out);
      if (precise_) {
        softmax<float16_t, float>(in, out);
      } else {
        softmax<float16_t, float16_t>(in, out);
      }
      break;
    case bfloat16:
      softmax<bfloat16_t>(in, out);
      if (precise_) {
        softmax<bfloat16_t, float>(in, out);
      } else {
        softmax<bfloat16_t, bfloat16_t>(in, out);
      }
      break;
    case complex64:
      throw std::invalid_argument(

@@ -3,6 +3,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {
@@ -49,8 +50,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {

  // Will contain the indices of eigenvectors that failed to converge (not used
  // here but required by lapack).
  std::vector<int> iwork;
  iwork.resize(12 * K);
  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};

  static const int lwork_query = -1;

@@ -82,7 +82,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
      /* ldvt = */ &ldvt,
      /* work = */ &workspace_dimension,
      /* lwork = */ &lwork_query,
      /* iwork = */ iwork.data(),
      /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
      /* info = */ &info);

  if (info != 0) {
@@ -120,7 +120,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
      /* ldvt = */ &ldvt,
      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
      /* lwork = */ &lwork,
      /* iwork = */ iwork.data(),
      /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
      /* info = */ &info);

  if (info != 0) {
@@ -145,4 +145,12 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
  svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}

std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto ax = axes[0] >= 0 ? 0 : -1;
  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}

} // namespace mlx::core

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -8,11 +8,12 @@

namespace mlx::core {

inline size_t elem_to_loc(
template <typename stride_t>
inline stride_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
    const std::vector<stride_t>& strides) {
  stride_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    auto q_and_r = ldiv(elem, shape[i]);
    loc += q_and_r.rem * strides[i];
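// A worked example of the loop above (assuming the elided tail of the loop
// advances elem to the quotient): for shape = {2, 3} with column-major
// strides = {1, 2}, linear element 4 is row-major coordinate (1, 1), and
//   i = 1: ldiv(4, 3) -> quot 1, rem 1; loc += 1 * strides[1] = 2; elem = 1
//   i = 0: ldiv(1, 2) -> quot 0, rem 1; loc += 1 * strides[0] = 1; elem = 0
// yields loc = 3, the buffer position of logical element (1, 1).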
@@ -34,10 +35,11 @@ inline size_t elem_to_loc(int elem, const array& a) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<size_t>> strides) {
    const std::vector<std::vector<stride_t>> strides) {
  // Make a vector that has axes separated with -1. Collapse all axes between
  // -1.
  std::vector<int> to_collapse;
@@ -45,7 +47,7 @@ collapse_contiguous_dims(
  to_collapse.push_back(0);
  for (int i = 1; i < shape.size(); i++) {
    bool contiguous = true;
    for (const std::vector<size_t>& st : strides) {
    for (const std::vector<stride_t>& st : strides) {
      if (st[i] * shape[i] != st[i - 1]) {
        contiguous = false;
      }
@@ -62,7 +64,7 @@ collapse_contiguous_dims(
  }

  std::vector<int> out_shape;
  std::vector<std::vector<size_t>> out_strides(strides.size());
  std::vector<std::vector<stride_t>> out_strides(strides.size());
  for (int i = 0; i < to_collapse.size(); i++) {
    int current_shape = shape[to_collapse[i]];
    while (to_collapse[++i] != -1) {
@@ -70,7 +72,7 @@ collapse_contiguous_dims(
    }
    out_shape.push_back(current_shape);
    for (int j = 0; j < strides.size(); j++) {
      const std::vector<size_t>& st = strides[j];
      const std::vector<stride_t>& st = strides[j];
      out_strides[j].push_back(st[to_collapse[i - 1]]);
    }
  }
@@ -87,11 +89,33 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
  return collapse_contiguous_dims(xs[0].shape(), strides);
}

template <typename... Arrays>
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline auto collapse_contiguous_dims(Arrays&&... xs) {
  return collapse_contiguous_dims(
      std::vector<array>{std::forward<Arrays>(xs)...});
}

template <typename stride_t>
inline auto check_contiguity(
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  size_t data_size = 1;
  size_t f_stride = 1;
  size_t b_stride = 1;
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;

  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {
      data_size *= shape[i];
    }
  }

  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
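
// A hedged usage example: for a row-major 2x3 array,
//   auto [data_size, row_contig, col_contig] =
//       check_contiguity(std::vector<int>{2, 3}, std::vector<size_t>{3, 1});
// gives data_size == 6, row_contig == true, col_contig == false, which is
// exactly the computation Slice::shared_buffer_slice now delegates here
// instead of its old hand-rolled loop.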

} // namespace mlx::core

@@ -33,6 +33,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

@@ -229,14 +229,7 @@ void Compiled::eval_gpu(

  // Figure out which kernel we are using
  auto& output_shape = outputs[0].shape();
  bool contiguous = true;
  for (auto& x : inputs) {
    if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
        !is_scalar(x)) {
      contiguous = false;
      break;
    }
  }
  bool contiguous = compiled_check_contiguity(inputs, output_shape);

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
@@ -317,28 +310,8 @@ void Compiled::eval_gpu(
    }
  }

  // Allocate space for the outputs possibly with input donation
  {
    int o = 0;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
      auto& in = inputs[i];
      // Conditions for donation
      // - Row contiguous
      // - Donatable
      // - Correct size
      // - Not a constant
      if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        outputs[o].move_shared_buffer(
            in, outputs[o].strides(), in.flags(), in.data_size());
        o++;
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
    }
  }
  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous, true);

  // Put the outputs in
  for (auto& x : outputs) {

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <sstream>

@@ -12,8 +12,15 @@ namespace mlx::core {

void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
  if (ctype == CopyType::Vector) {
    // If the input is donateable, we are doing a vector copy and the types
    // have the same size, then the input buffer can hold the output.
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      out.move_shared_buffer(in);
      // If the output has the same type as the input then there is nothing to
      // copy, just use the buffer.
      if (in.dtype() == out.dtype()) {
        return;
      }
    } else {
      out.set_data(
          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -37,15 +44,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
  copy_gpu(in, out, ctype, out.primitive().stream());
}

template <typename stride_t>
void copy_gpu_inplace(
    const array& in,
    array& out,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& strides_in_pre,
    const std::vector<stride_t>& strides_out_pre,
    int64_t inp_offset,
    int64_t out_offset,
    CopyType ctype,
    const Stream& s) {
  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(in, out);
  auto& strides_in = strides[0];
  auto& strides_out = strides[1];
  auto [shape, strides] = collapse_contiguous_dims(
      data_shape, std::vector{strides_in_pre, strides_out_pre});
  auto& strides_in_ = strides[0];
  auto& strides_out_ = strides[1];

  auto& d = metal::device(s.device);
  std::ostringstream kname;
@@ -72,39 +86,44 @@ void copy_gpu_inplace(
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_in = in.data_shared_ptr() == nullptr;
  set_array_buffer(compute_encoder, donate_in ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);

  inp_offset *= size_of(in.dtype());
  out_offset *= size_of(out.dtype());

  set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
  set_array_buffer(compute_encoder, out, out_offset, 1);

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    size_t ndim = shape.size();
    int ndim = shape.size();
    std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
    std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};

    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
      if (ctype == CopyType::GeneralGeneral) {
        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
      }
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
      if (ctype == CopyType::GeneralGeneral) {
        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
      }
      set_vector_bytes(compute_encoder, shape, ndim, 2);
    }
    set_vector_bytes(compute_encoder, strides_in, ndim, 3);
    if (ctype == CopyType::GeneralGeneral) {
      set_vector_bytes(compute_encoder, strides_out, ndim, 4);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(
          &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
      compute_encoder->setBytes(&ndim, sizeof(int), 5);
    }

    int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    int rest = in.size() / (dim0 * dim1);

    size_t data_size = 1;
    for (auto& s : shape)
      data_size *= s;
    int rest = data_size / (dim0 * dim1);

    // NB assuming thread_group_size is a power of 2 larger than 32 x 32
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
    }

    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
@@ -120,4 +139,25 @@ void copy_gpu_inplace(
  }
}

void copy_gpu_inplace(
    const array& in,
    array& out,
    CopyType ctype,
    const Stream& s) {
  return copy_gpu_inplace(
      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}

void copy_gpu_inplace(
    const array& in,
    array& out,
    const std::vector<int64_t>& istride,
    int64_t ioffset,
    CopyType ctype,
    const Stream& s) {
  std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
  return copy_gpu_inplace(
      in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}

} // namespace mlx::core

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -7,12 +7,34 @@

namespace mlx::core {

// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
    const array& in,
    array& out,
    const std::vector<int>& data_shape,
    const std::vector<stride_t>& i_strides,
    const std::vector<stride_t>& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype,
    const Stream& s);

void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);

void copy_gpu_inplace(
    const array& src,
    array& out,
    CopyType ctype,
    const Stream& s);

void copy_gpu_inplace(
    const array& in,
    array& out,
    const std::vector<int64_t>& istride,
    int64_t ioffset,
    CopyType ctype,
    const Stream& s);

} // namespace mlx::core

@@ -12,6 +12,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"

namespace fs = std::filesystem;

@@ -20,9 +21,9 @@ namespace mlx::core::metal {
namespace {

// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_BUFFERS_PER_QUEUE = 12;

static constexpr const char* default_mtllib_path = METAL_PATH;
constexpr const char* default_mtllib_path = METAL_PATH;

auto load_device() {
  auto devices = MTL::CopyAllDevices();
@@ -145,6 +146,7 @@ void Device::new_queue(int index) {
  // We lock this as a critical section for safety
  const std::lock_guard<std::mutex> lock(mtx_);
  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
  debug_set_stream_queue_label(q, index);
  if (!q) {
    throw std::runtime_error(
        "[metal::Device] Failed to make new command queue.");

@@ -16,7 +16,7 @@ namespace mlx::core {

namespace {

static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
constexpr int METAL_MAX_INDEX_ARRAYS = 10;

} // namespace

@@ -23,6 +23,8 @@ set(
  "gemv"
  "quantized"
  "random"
  "rms_norm"
  "layer_norm"
  "rope"
  "scan"
  "scaled_dot_product_attention"
@@ -35,11 +37,17 @@ set(
)

function(build_kernel_base TARGET SRCFILE DEPS)
  set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
  if(MLX_METAL_DEBUG)
    set(METAL_FLAGS ${METAL_FLAGS}
        -gline-tables-only
        -frecord-sources)
  endif()
  add_custom_command(
    COMMAND xcrun -sdk macosx metal -Wall -Wextra
                  -fno-fast-math
                  -c ${SRCFILE}
                  -I${PROJECT_SOURCE_DIR}
    COMMAND xcrun -sdk macosx metal
                  ${METAL_FLAGS}
                  -c ${SRCFILE}
                  -I${PROJECT_SOURCE_DIR}
                  -o ${TARGET}.air
    DEPENDS ${SRCFILE} ${DEPS}
    OUTPUT ${TARGET}.air

@@ -1,29 +1,29 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src,
    device U* dst,
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}

template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src,
    device U* dst,
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t& src_stride,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
@@ -31,61 +31,61 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[2],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[3],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_g_nd(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int src_shape[DIM],
|
||||
constant const size_t src_strides[DIM],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int* src_shape,
|
||||
constant const size_t* src_strides,
|
||||
constant const int& ndim,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t& src_stride,
|
||||
constant const size_t& dst_stride,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
@@ -94,10 +94,10 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[2],
|
||||
constant const size_t dst_strides[2],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
@@ -106,10 +106,10 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const size_t src_strides[3],
|
||||
constant const size_t dst_strides[3],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
@@ -118,11 +118,11 @@ template <typename T, typename U>
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_gg_nd(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int src_shape[DIM],
|
||||
constant const size_t src_strides[DIM],
|
||||
constant const size_t dst_strides[DIM],
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
|
||||
@@ -131,12 +131,12 @@ template <typename T, typename U, int DIM>
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src,
|
||||
device U* dst,
|
||||
constant const int* src_shape,
|
||||
constant const size_t* src_strides,
|
||||
constant const size_t* dst_strides,
|
||||
constant const int& ndim,
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
|
||||
@@ -146,70 +146,70 @@ template <typename T, typename U>
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_##ctype<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name(name "_" #dims)]] \
|
||||
[[kernel]] void copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int src_shape[dims], \
|
||||
constant const size_t src_strides[dims], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_" #dims)]] \
|
||||
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int src_shape[dims], \
|
||||
constant const size_t src_strides[dims], \
|
||||
constant const size_t dst_strides[dims], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name(name "_1")]] \
|
||||
[[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t& src_stride, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] \
|
||||
[[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[2], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] \
|
||||
[[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[3], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_1")]] \
|
||||
[[kernel]] void copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t& src_stride, \
|
||||
constant const size_t& dst_stride, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_2")]] \
|
||||
[[kernel]] void copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[2], \
|
||||
constant const size_t dst_strides[2], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_3")]] \
|
||||
[[kernel]] void copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const size_t src_strides[3], \
|
||||
constant const size_t dst_strides[3], \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
@@ -218,21 +218,21 @@ template <typename T, typename U>
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name(name)]] \
|
||||
[[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int* src_shape, \
|
||||
constant const size_t* src_strides, \
|
||||
constant const int& ndim, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name)]] \
|
||||
[[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src, \
|
||||
device otype* dst, \
|
||||
constant const int* src_shape, \
|
||||
constant const size_t* src_strides, \
|
||||
constant const size_t* dst_strides, \
|
||||
constant const int& ndim, \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
|
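The size_t-to-int64_t stride change above runs through the elem_to_loc* helpers. A host-side model of that index math (an assumption about the helpers' behavior, useful for checking the kernels by hand):

#include <cstdint>
#include <vector>

// Map a flat element id to a buffer offset through shape/strides, walking
// from the innermost dimension outward, as copy_g does per thread.
int64_t elem_to_loc(int64_t elem,
                    const std::vector<int>& shape,
                    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}
// e.g. shape {2, 3} with strides {1, 2} (a transposed layout):
// elem 4 -> coordinates (1, 1) -> 1 * 1 + 1 * 2 = 3.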
@@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
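A sketch of how these constants plausibly steer kernel selection (an assumption about the host-side dispatch, consistent with the kernel names in the new files below): each thread reads *_N_READS elements, and once the reduced axis grows past *_LOOPED_LIMIT the looped variant takes over from the single-row one.

constexpr int RMS_N_READS = 4;
constexpr int RMS_LOOPED_LIMIT = 4096;

// Hypothetical host-side choice between the two rms_norm kernels.
const char* pick_rms_kernel(int axis_size) {
  return axis_size <= RMS_LOOPED_LIMIT ? "rms_single_row" : "rms_looped";
}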
mlx/backend/metal/kernels/layer_norm.metal (new file, 553 lines)
@@ -0,0 +1,553 @@
// Copyright © 2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_single_row(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;
  float thread_x[N_READS];

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      sumx2 += thread_x[i] * thread_x[i];
      sumx += thread_x[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        sumx2 += thread_x[i] * thread_x[i];
        sumx += thread_x[i];
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
      out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
        out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_looped(
    const device T* x,
    const device T* w,
    const device T* b,
    device T* out,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    constant uint& b_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float sumx = 0;
  float sumx2 = 0;

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];

  x += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;

  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        sumx2 += xi * xi;
        sumx += xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          sumx2 += xi * xi;
          sumx += xi;
        }
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];

  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = (x[r + i] - mean) * normalizer;
        out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = (x[r + i] - mean) * normalizer;
          out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
        }
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_single_row(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the computation and accumulators
  float thread_x[N_READS];
  float thread_w[N_READS];
  float thread_g[N_READS];
  float sumx = 0;
  float sumx2 = 0;
  float sumwg = 0;
  float sumwgx = 0;

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumwg[SIMD_SIZE];
  threadgroup float local_sumwgx[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];
  threadgroup float local_meanwg[1];
  threadgroup float local_meanwgx[1];

  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      thread_w[i] = w[i * w_stride];
      thread_g[i] = g[i];
      float wg = thread_w[i] * thread_g[i];
      sumx += thread_x[i];
      sumx2 += thread_x[i] * thread_x[i];
      sumwg += wg;
      sumwgx += wg * thread_x[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        thread_w[i] = w[i * w_stride];
        thread_g[i] = g[i];
        float wg = thread_w[i] * thread_g[i];
        sumx += thread_x[i];
        sumx2 += thread_x[i] * thread_x[i];
        sumwg += wg;
        sumwgx += wg * thread_x[i];
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);
  sumwg = simd_sum(sumwg);
  sumwgx = simd_sum(sumwgx);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
    local_sumwg[simd_lane_id] = 0;
    local_sumwgx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
    local_sumwg[simd_group_id] = sumwg;
    local_sumwgx[simd_group_id] = sumwgx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumwg = simd_sum(local_sumwg[simd_lane_id]);
    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
      local_meanwg[0] = sumwg / axis_size;
      local_meanwgx[0] = sumwgx / axis_size;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];
  float meanwg = local_meanwg[0];
  float meanwgxc = local_meanwgx[0] - meanwg * mean;
  float normalizer2 = normalizer * normalizer;

  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
      gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
          thread_x[i] * meanwgxc * normalizer2);
      gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = (thread_x[i] - mean) * normalizer;
        gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
            thread_x[i] * meanwgxc * normalizer2);
        gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
      }
    }
  }
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_layer_norm_looped(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the accumulators
  float sumx = 0;
  float sumx2 = 0;
  float sumwg = 0;
  float sumwgx = 0;

  constexpr int SIMD_SIZE = 32;

  threadgroup float local_sumx[SIMD_SIZE];
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumwg[SIMD_SIZE];
  threadgroup float local_sumwgx[SIMD_SIZE];
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];
  threadgroup float local_meanwg[1];
  threadgroup float local_meanwgx[1];

  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        float wi = w[(i + r) * w_stride];
        float gi = g[i + r];
        float wg = wi * gi;
        sumx += xi;
        sumx2 += xi * xi;
        sumwg += wg;
        sumwgx += wg * xi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          float wi = w[(i + r) * w_stride];
          float gi = g[i + r];
          float wg = wi * gi;
          sumx += xi;
          sumx2 += xi * xi;
          sumwg += wg;
          sumwgx += wg * xi;
        }
      }
    }
  }

  sumx = simd_sum(sumx);
  sumx2 = simd_sum(sumx2);
  sumwg = simd_sum(sumwg);
  sumwgx = simd_sum(sumwgx);

  // Initialize shared memory
  if (simd_group_id == 0) {
    local_sumx[simd_lane_id] = 0;
    local_sumx2[simd_lane_id] = 0;
    local_sumwg[simd_lane_id] = 0;
    local_sumwgx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Write simd accumulations into shared memory
  if (simd_lane_id == 0) {
    local_sumx[simd_group_id] = sumx;
    local_sumx2[simd_group_id] = sumx2;
    local_sumwg[simd_group_id] = sumwg;
    local_sumwgx[simd_group_id] = sumwgx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Accumulate over simd groups
  if (simd_group_id == 0) {
    sumx = simd_sum(local_sumx[simd_lane_id]);
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumwg = simd_sum(local_sumwg[simd_lane_id]);
    sumwgx = simd_sum(local_sumwgx[simd_lane_id]);
    if (simd_lane_id == 0) {
      float mean = sumx / axis_size;
      float variance = sumx2 / axis_size - mean * mean;

      local_mean[0] = mean;
      local_normalizer[0] = metal::precise::rsqrt(variance + eps);
      local_meanwg[0] = sumwg / axis_size;
      local_meanwgx[0] = sumwgx / axis_size;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  float mean = local_mean[0];
  float normalizer = local_normalizer[0];
  float meanwg = local_meanwg[0];
  float meanwgxc = local_meanwgx[0] - meanwg * mean;
  float normalizer2 = normalizer * normalizer;

  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = (x[i + r] - mean) * normalizer;
        float wi = w[(i + r) * w_stride];
        float gi = g[i + r];
        gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
            xi * meanwgxc * normalizer2);
        gw[i + r] = static_cast<T>(gi * xi);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = (x[i + r] - mean) * normalizer;
          float wi = w[(i + r) * w_stride];
          float gi = g[i + r];
          gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
              xi * meanwgxc * normalizer2);
          gw[i + r] = static_cast<T>(gi * xi);
        }
      }
    }
  }
}

// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
  template [[host_name("layer_norm" #name)]] [[kernel]] void \
  layer_norm_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* b, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      constant uint& b_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("vjp_layer_norm" #name)]] [[kernel]] void \
  vjp_layer_norm_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm_looped(name, itype) \
  template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
  layer_norm_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* b, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      constant uint& b_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("vjp_layer_norm_looped" #name)]] [[kernel]] void \
  vjp_layer_norm_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gb, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_layer_norm(name, itype) \
  instantiate_layer_norm_single_row(name, itype) \
  instantiate_layer_norm_looped(name, itype)

instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
// clang-format on
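A scalar reference for the forward kernel above, handy for spot-checking the threadgroup reduction: it computes the same mean, the variance as E[x^2] - E[x]^2, and the rsqrt(variance + eps) normalizer. This is a checking sketch, not MLX API:

#include <cmath>
#include <vector>

void layer_norm_row(const std::vector<float>& x,
                    const std::vector<float>& w,
                    const std::vector<float>& b,
                    std::vector<float>& out,
                    float eps) {
  const float n = static_cast<float>(x.size());
  float sumx = 0.f, sumx2 = 0.f;
  for (float xi : x) {
    sumx += xi;
    sumx2 += xi * xi;
  }
  const float mean = sumx / n;
  const float variance = sumx2 / n - mean * mean;  // E[x^2] - E[x]^2
  const float normalizer = 1.f / std::sqrt(variance + eps);
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = w[i] * ((x[i] - mean) * normalizer) + b[i];
  }
}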
@@ -15,14 +15,6 @@ using namespace metal;

MLX_MTL_CONST int SIMD_SIZE = 32;

template <typename T> struct AccT {
  typedef T acc_t;
};

template <> struct AccT<bfloat16_t> {
  typedef float acc_t;
};


template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T *x, thread U *x_thread) {
@@ -60,6 +52,51 @@ inline U load_vector(const device T *x, thread U *x_thread) {
  return sum;
}

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
      x_thread[i] = x[i];
      x_thread[i+1] = x[i+1] / 4.0f;
      x_thread[i+2] = x[i+2] / 16.0f;
      x_thread[i+3] = x[i+3] / 64.0f;
    }
    for (int i=N; i<values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
      x_thread[i] = x[i];
      x_thread[i+1] = x[i+1] / 16.0f;
      x_thread[i+2] = x[i+2] / 256.0f;
      x_thread[i+3] = x[i+3] / 4096.0f;
    }
    for (int i=N; i<values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
    for (int i=N; i<values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  return sum;
}

template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
@@ -96,6 +133,74 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
  return scale * accum + sum * bias;
}

template <typename U, int values_per_thread, int bits>
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum += (
        x_thread[4*i] * (w[i] & 0x03)
        + x_thread[4*i+1] * (w[i] & 0x0c)
        + x_thread[4*i+2] * (w[i] & 0x30)
        + x_thread[4*i+3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum += (
        x_thread[4*i] * (ws[i] & 0x000f)
        + x_thread[4*i+1] * (ws[i] & 0x00f0)
        + x_thread[4*i+2] * (ws[i] & 0x0f00)
        + x_thread[4*i+3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}

template <typename U, int values_per_thread, int bits>
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 4) {
    const thread uint16_t* ws = (const thread uint16_t*)w;
    U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
      result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
      result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
      result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}

template <typename T, int group_size, int bits, int packs_per_thread>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
@@ -204,7 +309,8 @@ template <typename T, const int group_size, const int bits>
  x += tid.z * in_vec_size + simd_lid * values_per_thread;
  y += tid.z * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
  int k = 0;
  for (; k < in_vec_size-block_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; out_row + row < out_vec_size; row++) {
@@ -222,6 +328,18 @@ template <typename T, const int group_size, const int bits>
    biases += block_size / group_size;
    x += block_size;
  }
  const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
  U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);

  for (int row = 0; out_row + row < out_vec_size; row++) {
    const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
    const device T* sl = scales + row * in_vec_size_g;
    const device T* bl = biases + row * in_vec_size_g;

    U s = sl[0];
    U b = bl[0];
    result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
  }

  for (int row = 0; out_row + row < out_vec_size; row++) {
    result[row] = simd_sum(result[row]);
@@ -239,7 +357,8 @@ template <typename T, const int group_size, const int bits>
  x += tid.z * in_vec_size + simd_lid * values_per_thread;
  y += tid.z * out_vec_size + used_out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
  int k = 0;
  for (; k < in_vec_size-block_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
@@ -257,6 +376,18 @@ template <typename T, const int group_size, const int bits>
    biases += block_size / group_size;
    x += block_size;
  }
  const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
  U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);

  for (int row = 0; row < results_per_simdgroup; row++) {
    const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
    const device T* sl = scales + row * in_vec_size_g;
    const device T* bl = biases + row * in_vec_size_g;

    U s = sl[0];
    U b = bl[0];
    result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
@@ -268,7 +399,7 @@ template <typename T, const int group_size, const int bits>
}


template <typename T, const int BM, const int BN, const int group_size, const int bits>
template <typename T, const int group_size, const int bits>
[[kernel]] void qvm(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
@@ -278,39 +409,28 @@ template <typename T, const int BM, const int BN, const int group_size, const in
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
  static_assert(BN == BM, "qvm expects a block size of 32x32");
  constexpr int num_simdgroups = 8;
  constexpr int pack_factor = 32 / bits;
  constexpr int blocksize = SIMD_SIZE;

  (void)lid;

  constexpr int bitmask = (1 << bits) - 1;
  constexpr int el_per_int = 32 / bits;
  constexpr int colgroup = BN * el_per_int;
  constexpr int groups_per_block = colgroup / group_size;

  typedef typename AccT<T>::acc_t U;
  threadgroup U scales_block[BM * groups_per_block];
  threadgroup U biases_block[BM * groups_per_block];
  threadgroup U x_block[BM];
  typedef float U;

  thread uint32_t w_local;
  thread U result[el_per_int] = {0};
  thread U result[pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size / el_per_int;
  const int out_vec_size_w = out_vec_size / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col_start = tid.y * (BN * el_per_int);
  int out_col = out_col_start + simd_gid * el_per_int;
  w += out_col / el_per_int;
  scales += out_col_start / group_size;
  biases += out_col_start / group_size;
  int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
  w += out_col / pack_factor;
  scales += out_col / group_size;
  biases += out_col / group_size;
  x += tid.z * in_vec_size;
  y += tid.z * out_vec_size + out_col;

@@ -318,53 +438,39 @@ template <typename T, const int BM, const int BN, const int group_size, const in
    return;
  }

  // Loop over in_vec in blocks of colgroup
  for (int i=0; i<in_vec_size; i+=BM) {
    int offset_lid = simd_lid + i;
    int offset_gid = simd_gid + i;
    bool thread_in_bounds = offset_lid < in_vec_size;
    bool group_in_bounds = offset_gid < in_vec_size;
  // Loop over in_vec in blocks of blocksize
  int i = 0;
  for (; i + blocksize <= in_vec_size; i += blocksize) {
    x_local = x[i + simd_lid];
    scale = scales[(i + simd_lid) * out_vec_size_g];
    bias = biases[(i + simd_lid) * out_vec_size_g];
    w_local = w[(i + simd_lid) * out_vec_size_w];

    // Load the vec to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_gid == 0) {
      x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
    }

    // Load the scales and biases to shared memory
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_lid < groups_per_block && group_in_bounds) {
      scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
      biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load in_vec, scale, bias to registers
    x_local = x_block[simd_lid];
    scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
    bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];

    // Load the matrix elements
    w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;

    // Do all the work.
    #pragma clang loop unroll(full)
    for (int k=0; k<el_per_int; k++) {
      result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
      w_local >>= bits;
    }
    qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
  }
  if (static_cast<int>(i + simd_lid) < in_vec_size) {
    x_local = x[i + simd_lid];
    scale = scales[(i + simd_lid) * out_vec_size_g];
    bias = biases[(i + simd_lid) * out_vec_size_g];
    w_local = w[(i + simd_lid) * out_vec_size_w];
  } else {
    x_local = 0;
    scale = 0;
    bias = 0;
    w_local = 0;
  }
  qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);

  // Accumulate in the simdgroup
  #pragma clang loop unroll(full)
  for (int k=0; k<el_per_int; k++) {
  for (int k=0; k<pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
    #pragma clang loop unroll(full)
    for (int k=0; k<el_per_int; k++) {
    for (int k=0; k<pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
@@ -414,6 +520,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  x += y_row * K;
  w += y_col * K_w;
  scales += y_col * K_g;
@@ -466,7 +573,10 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
  const device uint32_t * w_local = w + offset_row * K_w + offset_col;
  threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;

  if (y_row + offset_row < N) {
  // y_col corresponds to the row of the weight matrix and added to
  // offset_row it should be less than the total number of rows
  // otherwise skip.
  if (y_col + offset_row < N) {
    uint32_t wi = *w_local;
    T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
    T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -619,7 +729,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
  const device uint32_t * w_local = w + offset_row * N_w + offset_col;
  threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;

  if (y_row + offset_row < K) {
  if (k + offset_row < K) {
    uint32_t wi = *w_local;
    T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
    T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
@@ -738,7 +848,7 @@ instantiate_qmv_types( 32, 8)

#define instantiate_qvm(name, itype, group_size, bits) \
  template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
  [[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
  [[kernel]] void qvm<itype, group_size, bits>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
@@ -747,7 +857,6 @@ instantiate_qmv_types( 32, 8)
      const constant int& in_vec_size [[buffer(5)]], \
      const constant int& out_vec_size [[buffer(6)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint lid [[thread_index_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
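Why load_vector* pre-divides activations by 4/16/64 (2-bit) or 16/256/4096 (4-bit): qdot/qdot_safe then multiply by the masked-but-unshifted weight fields, so the power-of-two factors cancel and the inner loop needs no bit shifts. A scalar model of one 4-bit packed word (assuming the same low-nibble-first packing as the kernel):

#include <cstdint>

float qdot4_model(uint16_t w, const float x[4], float scale, float bias, float sum) {
  // x[j] was pre-divided by 16^j when loaded; sum is the raw sum of x.
  float accum = 0.f;
  accum += x[0] * (w & 0x000f);  // q0
  accum += x[1] * (w & 0x00f0);  // 16 * q1, cancels the /16 in x[1]
  accum += x[2] * (w & 0x0f00);  // 256 * q2
  accum += x[3] * (w & 0xf000);  // 4096 * q3
  return scale * accum + sum * bias;  // dequantized weight = scale * q + bias
}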
@@ -108,15 +108,17 @@ template <typename T, typename U, typename Op>
  const short i_ed = short(reduction_size);
  const short i_jump = reductions_per_thread;

  for(short r = r_st; r < r_ed; r += r_jump) {
  if(r_st < r_jump) {
    for(short r = r_st; r < r_ed; r += r_jump) {

      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
      const device T * in_row = in + in_idx;
      uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
      const device T * in_row = in + in_idx;

      for(short i = i_st; i < i_ed; i += i_jump) {
        total_val = op(static_cast<U>(in_row[i]), total_val);
      }

      for(short i = i_st; i < i_ed; i += i_jump) {
        total_val = op(static_cast<U>(in_row[i]), total_val);
      }

    }

  }
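A simplified scalar model of the control-flow change in this hunk (my reading of the diff; r_st, r_ed, and r_jump come from the surrounding kernel): the row loop is now wrapped in an if (r_st < r_jump) guard, so threads whose starting offset falls outside the first stride window skip the traversal entirely rather than entering the loop.

void reduce_rows_model(int r_st, int r_ed, int r_jump) {
  if (r_st < r_jump) {  // hoisted guard added by the diff
    for (int r = r_st; r < r_ed; r += r_jump) {
      // ... elem_to_loc(...) the row, then accumulate into total_val ...
    }
  }
}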
mlx/backend/metal/kernels/rms_norm.metal (new file, 435 lines)
@@ -0,0 +1,435 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void rms_single_row(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
device T* out,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]],
|
||||
threadgroup float* local_sums [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
float acc = 0;
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i];
|
||||
acc += xi * xi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i];
|
||||
acc += xi * xi;
|
||||
}
|
||||
}
|
||||
}
|
||||
acc = simd_sum(acc);
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sums[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sums[simd_group_id] = acc;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
acc = simd_sum(local_sums[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write the outputs
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
if (lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((lid * N_READS + i) < axis_size) {
|
||||
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int N_READS = RMS_N_READS>
|
||||
[[kernel]] void rms_looped(
|
||||
const device T* x,
|
||||
const device T* w,
|
||||
device T* out,
|
||||
constant float& eps,
|
||||
constant uint& axis_size,
|
||||
constant uint& w_stride,
|
||||
threadgroup float* local_inv_mean [[threadgroup(0)]],
|
||||
threadgroup float* local_sums [[threadgroup(1)]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
float acc = 0;
|
||||
x += gid * axis_size + lid * N_READS;
|
||||
w += w_stride * lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
float xi = x[i + r];
|
||||
acc += xi * xi;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
float xi = x[i + r];
|
||||
acc += xi * xi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
acc = simd_sum(acc);
|
||||
// Initialize shared memory
|
||||
if (simd_group_id == 0) {
|
||||
local_sums[simd_lane_id] = 0;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write simd accumulations into shared memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_sums[simd_group_id] = acc;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Accumulate over simd groups
|
||||
if (simd_group_id == 0) {
|
||||
acc = simd_sum(local_sums[simd_lane_id]);
|
||||
if (simd_lane_id == 0) {
|
||||
local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write the outputs
|
||||
out += gid * axis_size + lid * N_READS;
|
||||
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
|
||||
if (r + lid * N_READS + N_READS <= axis_size) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[r + i] = w[w_stride * (i + r)] *
|
||||
static_cast<T>(x[r + i] * local_inv_mean[0]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
if ((r + lid * N_READS + i) < axis_size) {
|
||||
out[r + i] = w[w_stride * (i + r)] *
|
||||
static_cast<T>(x[r + i] * local_inv_mean[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_single_row(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the computation and accumulators
  float thread_x[N_READS];
  float thread_w[N_READS];
  float thread_g[N_READS];
  float sumx2 = 0;
  float sumgwx = 0;

  // Allocate shared memory to implement the reduction
  constexpr int SIMD_SIZE = 32;
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumgwx[SIMD_SIZE];
  threadgroup float local_normalizer[1];
  threadgroup float local_meangwx[1];

  // Read and accumulate locally
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = x[i];
      thread_w[i] = w[w_stride * i];
      thread_g[i] = g[i];

      sumx2 += thread_x[i] * thread_x[i];
      sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        thread_x[i] = x[i];
        thread_w[i] = w[w_stride * i];
        thread_g[i] = g[i];

        sumx2 += thread_x[i] * thread_x[i];
        sumgwx += thread_x[i] * thread_w[i] * thread_g[i];
      }
    }
  }

  // Accumulate across threads
  sumx2 = simd_sum(sumx2);
  sumgwx = simd_sum(sumgwx);
  if (simd_group_id == 0) {
    local_sumx2[simd_lane_id] = 0;
    local_sumgwx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_lane_id == 0) {
    local_sumx2[simd_group_id] = sumx2;
    local_sumgwx[simd_group_id] = sumgwx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_meangwx[0] = sumgwx / axis_size;
      local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float meangwx = local_meangwx[0];
  float normalizer = local_normalizer[0];
  float normalizer3 = normalizer * normalizer * normalizer;

  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
      gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
        gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
      }
    }
  }
}
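// Note: with n = rsqrt(mean(x^2) + eps) the forward computes y = w * x * n;
// differentiating gives gx = g * w * n - x * mean(g * w * x) * n^3 and
// gw = g * x * n, which is exactly what the two branches above write out.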

template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void vjp_rms_looped(
    const device T* x,
    const device T* w,
    const device T* g,
    device T* gx,
    device T* gw,
    constant float& eps,
    constant uint& axis_size,
    constant uint& w_stride,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  w += w_stride * lid * N_READS;

  // Allocate registers for the accumulators
  float sumx2 = 0;
  float sumgwx = 0;

  // Allocate shared memory to implement the reduction
  constexpr int SIMD_SIZE = 32;
  threadgroup float local_sumx2[SIMD_SIZE];
  threadgroup float local_sumgwx[SIMD_SIZE];
  threadgroup float local_normalizer[1];
  threadgroup float local_meangwx[1];

  // Read and accumulate locally
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        float wi = w[w_stride * (i + r)];
        float gi = g[i + r];

        sumx2 += xi * xi;
        sumgwx += xi * wi * gi;
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          float wi = w[w_stride * (i + r)];
          float gi = g[i + r];

          sumx2 += xi * xi;
          sumgwx += xi * wi * gi;
        }
      }
    }
  }

  // Accumulate across threads
  sumx2 = simd_sum(sumx2);
  sumgwx = simd_sum(sumgwx);
  if (simd_group_id == 0) {
    local_sumx2[simd_lane_id] = 0;
    local_sumgwx[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_lane_id == 0) {
    local_sumx2[simd_group_id] = sumx2;
    local_sumgwx[simd_group_id] = sumgwx;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    sumx2 = simd_sum(local_sumx2[simd_lane_id]);
    sumgwx = simd_sum(local_sumgwx[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_meangwx[0] = sumgwx / axis_size;
      local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps);
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float meangwx = local_meangwx[0];
  float normalizer = local_normalizer[0];
  float normalizer3 = normalizer * normalizer * normalizer;

  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        float xi = x[i + r];
        float wi = w[w_stride * (i + r)];
        float gi = g[i + r];

        gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
        gw[i + r] = static_cast<T>(gi * xi * normalizer);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if ((r + lid * N_READS + i) < axis_size) {
          float xi = x[i + r];
          float wi = w[w_stride * (i + r)];
          float gi = g[i + r];

          gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
          gw[i + r] = static_cast<T>(gi * xi * normalizer);
        }
      }
    }
  }
}

// clang-format off
#define instantiate_rms_single_row(name, itype) \
  template [[host_name("rms" #name)]] [[kernel]] void \
  rms_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      threadgroup float* local_inv_mean [[threadgroup(0)]], \
      threadgroup float* local_sums [[threadgroup(1)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  \
  template [[host_name("vjp_rms" #name)]] [[kernel]] void \
  vjp_rms_single_row<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms_looped(name, itype) \
  template [[host_name("rms_looped" #name)]] [[kernel]] void \
  rms_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      device itype* out, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      threadgroup float* local_inv_mean [[threadgroup(0)]], \
      threadgroup float* local_sums [[threadgroup(1)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  \
  template [[host_name("vjp_rms_looped" #name)]] [[kernel]] void \
  vjp_rms_looped<itype>( \
      const device itype* x, \
      const device itype* w, \
      const device itype* g, \
      device itype* gx, \
      device itype* gw, \
      constant float& eps, \
      constant uint& axis_size, \
      constant uint& w_stride, \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_rms(name, itype) \
  instantiate_rms_single_row(name, itype) \
  instantiate_rms_looped(name, itype)

instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on

@@ -5,11 +5,12 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, bool traditional>
template <typename T, bool traditional, bool forward>
[[kernel]] void rope(
    const device T *in [[buffer(0)]],
    device T * out [[buffer(1)]],
    constant const size_t strides[3],
    constant const size_t out_strides[3],
    constant const int& offset,
    constant const float& base,
    constant const float& scale,
@@ -19,13 +20,13 @@ template <typename T, bool traditional>
  uint in_index_1, in_index_2;
  uint out_index_1, out_index_2;
  if (traditional) {
    out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
    out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
    out_index_2 = out_index_1 + 1;
    in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + strides[2];
  } else {
    out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
    out_index_2 = out_index_1 + grid.x;
    out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
    out_index_2 = out_index_1 + grid.x * out_strides[2];
    in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
    in_index_2 = in_index_1 + grid.x * strides[2];
  }
@@ -42,27 +43,41 @@ template <typename T, bool traditional>
  // Read and write the output
  float x1 = static_cast<float>(in[in_index_1]);
  float x2 = static_cast<float>(in[in_index_2]);
  float rx1 = x1 * costheta - x2 * sintheta;
  float rx2 = x1 * sintheta + x2 * costheta;
  float rx1;
  float rx2;
  if (forward) {
    rx1 = x1 * costheta - x2 * sintheta;
    rx2 = x1 * sintheta + x2 * costheta;
  } else {
    rx1 = x2 * sintheta + x1 * costheta;
    rx2 = x2 * costheta - x1 * sintheta;
  }
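  // Note: the else branch applies the rotation by -theta (the transpose of
  // the forward rotation matrix), which is what the RoPE VJP requires:
  // rx1 = x1 * cos(-t) - x2 * sin(-t) and rx2 = x1 * sin(-t) + x2 * cos(-t).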
  out[out_index_1] = static_cast<T>(rx1);
  out[out_index_2] = static_cast<T>(rx2);
}

#define instantiate_rope(name, type, traditional) \
#define instantiate_rope(name, type, traditional, forward) \
  template [[host_name("rope_" #name)]] \
  [[kernel]] void rope<type, traditional>( \
  [[kernel]] void rope<type, traditional, forward>( \
      const device type* in [[buffer(0)]], \
      device type* out [[buffer(1)]], \
      constant const size_t strides[3], \
      constant const size_t out_strides[3], \
      constant const int& offset, \
      constant const float& base, \
      constant const float& scale, \
      uint3 pos [[thread_position_in_grid]], \
      uint3 grid [[threads_per_grid]]);

instantiate_rope(traditional_float16, half, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
instantiate_rope(traditional_float32, float, true)
instantiate_rope(float16, half, false)
instantiate_rope(bfloat16, bfloat16_t, false)
instantiate_rope(float32, float, false)
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
instantiate_rope(float16, half, false, true)
instantiate_rope(bfloat16, bfloat16_t, false, true)
instantiate_rope(float32, float, false, true)
instantiate_rope(vjp_traditional_float16, half, true, false)
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false)

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.

#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>

@@ -12,46 +11,48 @@ using namespace metal;

template <typename T>
inline T softmax_exp(T x) {
  // Softmax doesn't need high precision exponential cause it is gonna be x
  // will be in (-oo, 0] anyway and subsequently it will be divided by
  // sum(exp(x_i)).
  // Softmax doesn't need high precision exponential cause x is gonna be in
  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
  return fast::exp(x);
}
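// Note: the argument is always x_i - max(x), so fast::exp only ever sees
// non-positive inputs; there is no overflow risk, and small relative error
// washes out when dividing by sum(exp(x_i - max(x))).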

template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  T ld[N_READS];
  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i=0; i<N_READS; i++) {
      ld[i] = in[i];
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      ld[i] =
          ((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
    }
    for (int i = 0; i < N_READS; i++) {
      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
                                                : Limits<AccT>::finite_min;
    }
  }
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<T>::finite_min;
    local_max[simd_lane_id] = Limits<AccT>::finite_min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max
  T maxval = Limits<T>::finite_min;
  AccT maxval = Limits<AccT>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
@@ -70,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
  maxval = local_max[0];

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  T normalizer = 0;
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    T exp_x = softmax_exp(ld[i] - maxval);
    AccT exp_x = softmax_exp(ld[i] - maxval);
    ld[i] = exp_x;
    normalizer += exp_x;
  }
@@ -93,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
  // Normalize and write to the output
  out += gid * axis_size + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i=0; i<N_READS; i++) {
      out[i] = ld[i] * normalizer;
    for (int i = 0; i < N_READS; i++) {
      out[i] = T(ld[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = ld[i] * normalizer;
      }
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = T(ld[i] * normalizer);
      }
    }
  }
}

template <typename T, int N_READS = SOFTMAX_N_READS>
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    threadgroup T* local_max [[threadgroup(0)]],
    threadgroup T* local_normalizer [[threadgroup(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
@@ -119,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * axis_size;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  T prevmax;
  T maxval = Limits<T>::finite_min;
  T normalizer = 0;
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    T vals[N_READS];
    AccT vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = in[offset + i];
        vals[i] = AccT(in[offset + i]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        vals[i] =
            (offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
                                           : Limits<AccT>::finite_min;
      }
    }
    prevmax = maxval;
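    // (The lines elided by this hunk presumably update maxval and rescale the
    // running normalizer by softmax_exp(prevmax - maxval) -- the usual
    // online-softmax recurrence -- judging by the prevmax/maxval/normalizer
    // bookkeeping above.)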
@@ -180,49 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i=0; i<N_READS; i++) {
        out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
          out[offset + i] =
              T(softmax_exp(in[offset + i] - maxval) * normalizer);
        }
      }
    }
  }
}

#define instantiate_softmax_single_row(name, itype) \
// clang-format off
#define instantiate_softmax(name, itype) \
  template [[host_name("softmax_" #name)]] [[kernel]] void \
  softmax_single_row<itype>( \
      const device itype* in, \
      device itype* out, \
      constant int& axis_size, \
      threadgroup itype* local_max [[threadgroup(0)]], \
      threadgroup itype* local_normalizer [[threadgroup(1)]], \
      uint gid [[thread_position_in_grid]], \
      uint _lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_softmax_looped(name, itype) \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
  softmax_looped<itype>( \
      const device itype* in, \
      device itype* out, \
      constant int& axis_size, \
      threadgroup itype* local_max [[threadgroup(0)]], \
      threadgroup itype* local_normalizer [[threadgroup(1)]], \
      uint gid [[threadgroup_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_softmax(name, itype) \
  instantiate_softmax_single_row(name, itype) \
  instantiate_softmax_looped(name, itype)
#define instantiate_softmax_precise(name, itype) \
  template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
  softmax_single_row<itype, float>( \
      const device itype* in, \
      device itype* out, \
      constant int& axis_size, \
      uint gid [[thread_position_in_grid]], \
      uint _lid [[thread_position_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
  template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
  softmax_looped<itype, float>( \
      const device itype* in, \
      device itype* out, \
      constant int& axis_size, \
      uint gid [[threadgroup_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

instantiate_softmax(float32, float) instantiate_softmax(float16, half)
    instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -65,12 +65,18 @@ struct Limits<bool> {
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

inline size_t elem_to_loc(
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims

template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
    uint elem,
    device const int* shape,
    device const size_t* strides,
    device const stride_t* strides,
    int ndim) {
  size_t loc = 0;
  stride_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
@@ -78,12 +84,13 @@ inline size_t elem_to_loc(
  return loc;
}

inline size_t elem_to_loc(
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
    uint elem,
    constant const int* shape,
    constant const size_t* strides,
    constant const stride_t* strides,
    int ndim) {
  size_t loc = 0;
  stride_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
@@ -91,52 +98,59 @@ inline size_t elem_to_loc(
  return loc;
}
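// Worked example: with shape = {2, 3} and strides = {3, 1} (row contiguous),
// elem = 4 gives loc = (4 % 3) * 1 = 1, elem = 1, then loc += (1 % 2) * 3,
// so loc = 4 -- the identity mapping, as expected for contiguous strides.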

template <int NDIM>
inline uint3 elem_to_loc_3_nd(
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM],
    constant const size_t c_strides[NDIM]) {
  uint3 loc = {
      static_cast<uint>(
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    loc.z += l * c_strides[d];
    constant const int* shape,
    constant const stride_t* strides,
    int ndim) {
  stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
  for (int d = ndim - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims

template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
  return elem * stride;
}

template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
  return elem.x * strides[1] + elem.y * strides[0];
}

template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}

template <int NDIM>
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM]) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
METAL_FUNC size_t elem_to_loc_nd(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];

  MLX_MTL_PRAGMA_UNROLL
  for (int d = NDIM - 2; d >= 0; --d) {
    elem /= shape[d + 1];
    loc += (elem % shape[d]) * strides[d];
  }

  return loc;
}

template <int NDIM>
inline size_t elem_to_loc_nd(
METAL_FUNC size_t elem_to_loc_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t strides[NDIM]) {
@@ -148,33 +162,59 @@ inline size_t elem_to_loc_nd(
  return loc;
}

inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
  return elem * stride;
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
    uint elem,
    constant const int shape[NDIM],
    constant const int64_t strides[NDIM]) {
  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];

  MLX_MTL_PRAGMA_UNROLL
  for (int d = NDIM - 2; d >= 0; --d) {
    elem /= shape[d + 1];
    loc += (elem % shape[d]) * strides[d];
  }

  return loc;
}

inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  return elem.x * strides[1] + elem.y * strides[0];
}

inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}

// Non templated version to handle arbitrary dims
inline size_t elem_to_loc(
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
  for (int d = ndim - 3; d >= 0; --d) {
    constant const int shape[NDIM],
    constant const int64_t strides[NDIM]) {
  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
  for (int d = NDIM - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

inline uint3 elem_to_loc_3_nd(
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims

METAL_FUNC uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    int ndim) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
      static_cast<uint>(
          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

METAL_FUNC uint3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
@@ -198,18 +238,21 @@ inline uint3 elem_to_loc_3_nd(
  return loc;
}

inline uint2 elem_to_loc_2_nd(
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with fixed N dims

template <int NDIM>
METAL_FUNC uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    int ndim) {
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM]) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
  for (int d = ndim - 3; d >= 0; --d) {
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
@@ -219,55 +262,26 @@ inline uint2 elem_to_loc_2_nd(
}

template <int NDIM>
inline uint elem_to_loc_nd(
    uint elem,
    device const int* shape,
    device const size_t* strides);

template <>
inline uint elem_to_loc_nd<1>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  return (elem % shape[0]) * strides[0];
}

template <>
inline uint elem_to_loc_nd<2>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}

template <>
inline uint elem_to_loc_nd<3>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[2]) * strides[2];
  elem /= shape[2];
  loc += (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}

template <>
inline uint elem_to_loc_nd<4>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[3]) * strides[3];
  elem /= shape[3];
  loc += (elem % shape[2]) * strides[2];
  elem /= shape[2];
  loc += (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
METAL_FUNC uint3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM],
    constant const size_t c_strides[NDIM]) {
  uint3 loc = {
      static_cast<uint>(
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    loc.z += l * c_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

@@ -206,7 +206,7 @@ inline auto collapse_batches(const array& a, const array& b) {
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});
      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});

  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];
@@ -237,8 +237,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
  std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});
  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});

  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];
@@ -488,7 +488,7 @@ void steel_matmul(

void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (!is_floating_point(out.dtype())) {
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[matmul] Does not yet support non-floating point types.");
  }
@@ -696,7 +696,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 3);
  if (!is_floating_point(out.dtype())) {
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[matmul] Does not yet support non-floating point types.");
  }

@@ -5,6 +5,7 @@
#include <memory>

#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

@@ -15,6 +16,9 @@ bool is_available() {
}

int max_ops_per_buffer() {
#ifdef MLX_METAL_DEBUG
  return 1;
#else
  auto get_val = []() {
    if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
      return atoi(buff_str);
@@ -24,6 +28,7 @@ int max_ops_per_buffer() {
  };
  static int max_ops_per_buffer_ = get_val();
  return max_ops_per_buffer_;
#endif
}
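// Usage note (the script name is illustrative): running
//   MLX_MAX_OPS_PER_BUFFER=20 python my_script.py
// caps how many ops are encoded per command buffer, while MLX_METAL_DEBUG
// builds always use 1 so each op lands in its own buffer.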

#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
@@ -74,6 +79,8 @@ std::function<void()> make_task(
      if (arr.is_tracer()) {
        inputs = arr.inputs();
      }

      debug_set_primitive_buffer_label(command_buffer, arr.primitive());
      arr.primitive().eval_gpu(arr.inputs(), outputs);
    }
    std::vector<std::shared_ptr<array::Data>> buffers;
@@ -108,4 +115,31 @@ std::function<void()> make_task(
  return task;
}

bool start_capture(std::string path, id object) {
  auto pool = new_scoped_memory_pool();

  auto descriptor = MTL::CaptureDescriptor::alloc()->init();
  descriptor->setCaptureObject(object);

  if (path.length() > 0) {
    auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
    auto url = NS::URL::fileURLWithPath(string);
    descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
    descriptor->setOutputURL(url);
  }

  auto manager = MTL::CaptureManager::sharedCaptureManager();
  return manager->startCapture(descriptor, nullptr);
}

bool start_capture(std::string path) {
  auto& device = metal::device(mlx::core::Device::gpu);
  return start_capture(path, device.mtl_device());
}

void stop_capture() {
  auto manager = MTL::CaptureManager::sharedCaptureManager();
  manager->stopCapture();
}

} // namespace mlx::core::metal

@@ -66,4 +66,8 @@ std::function<void()> make_task(
    std::vector<std::shared_future<void>> deps,
    std::shared_ptr<std::promise<void>> p);

/** Capture a GPU trace, saving it to an absolute file `path` */
bool start_capture(std::string path = "");
void stop_capture();
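// A minimal usage sketch (the .gputrace path is illustrative):
//   metal::start_capture("/tmp/mlx.gputrace");
//   ... evaluate the arrays you want traced ...
//   metal::stop_capture();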

} // namespace mlx::core::metal

420
mlx/backend/metal/normalization.cpp
Normal file
@@ -0,0 +1,420 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void RMSNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) -> const array& {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
      copy_gpu(x, copies.back(), CopyType::General, s);
      return copies.back();
    }
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];

  if (x.is_donatable()) {
    out.move_shared_buffer(x);
  } else {
    out.set_data(
        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "rms";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
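      // Worked example (assuming RMS_N_READS == 4): axis_size = 1000 gives
      // threadgroup_needed = 250, simds_needed = 8, threadgroup_size = 256,
      // i.e. the smallest multiple of the simd width that covers the row.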
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = w.strides()[0];
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(
        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(compute_encoder, out, 2);
    compute_encoder->setBytes(&eps_, sizeof(float), 3);
    compute_encoder->setBytes(&axis_size, sizeof(int), 4);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
    compute_encoder->setThreadgroupMemoryLength(
        16 * 8, 0); // minimum of 16 bytes
    compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

void RMSNormVJP::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Ensure row contiguity. We could relax this step by checking that the array
  // is contiguous (no broadcasts or holes) and that the input strides are the
  // same as the cotangent strides but for now this is simpler.
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) -> const array& {
    if (x.flags().row_contiguous) {
      return x;
    }
    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
    copy_gpu(x, copies.back(), CopyType::General, s);
    return copies.back();
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];
  const array& g = check_input(inputs[2]);
  array& gx = outputs[0];
  array& gw = outputs[1];

  // Allocate space for the outputs
  bool x_in_gx = false;
  bool g_in_gx = false;
  if (x.is_donatable()) {
    gx.move_shared_buffer(x);
    x_in_gx = true;
  } else if (g.is_donatable()) {
    gx.move_shared_buffer(g);
    g_in_gx = true;
  } else {
    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  // Allocate a temporary to store the gradients for w and initialize the
  // gradient accumulator to 0.
  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
  bool g_in_gw = false;
  if (!g_in_gx && g.is_donatable()) {
    gw_temp.move_shared_buffer(g);
    g_in_gw = true;
  } else {
    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
  }
  copies.push_back(gw_temp);
  {
    array zero(0, gw.dtype());
    copy_gpu(zero, gw, CopyType::Scalar, s);
    copies.push_back(std::move(zero));
  }

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "vjp_rms";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(gx);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = w.strides()[0];
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(
        compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
    set_array_buffer(compute_encoder, gx, 3);
    set_array_buffer(compute_encoder, gw_temp, 4);
    compute_encoder->setBytes(&eps_, sizeof(float), 5);
    compute_encoder->setBytes(&axis_size, sizeof(int), 6);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }

  ReductionPlan plan(
      ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
  strided_reduce_general_dispatch(
      gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);

  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

void LayerNorm::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& out = outputs[0];

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) -> const array& {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
      copy_gpu(x, copies.back(), CopyType::General, s);
      return copies.back();
    }
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];
  const array& b = inputs[2];

  if (x.is_donatable()) {
    out.move_shared_buffer(x);
  } else {
    out.set_data(
        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "layer_norm";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
    uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(
        compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(compute_encoder, b, 2);
    set_array_buffer(compute_encoder, out, 3);
    compute_encoder->setBytes(&eps_, sizeof(float), 4);
    compute_encoder->setBytes(&axis_size, sizeof(int), 5);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
    compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

void LayerNormVJP::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Ensure row contiguity. We could relax this step by checking that the array
  // is contiguous (no broadcasts or holes) and that the input strides are the
  // same as the cotangent strides but for now this is simpler.
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) -> const array& {
    if (x.flags().row_contiguous) {
      return x;
    }
    copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
    copy_gpu(x, copies.back(), CopyType::General, s);
    return copies.back();
  };
  const array& x = check_input(inputs[0]);
  const array& w = inputs[1];
  const array& b = inputs[2];
  const array& g = check_input(inputs[3]);
  array& gx = outputs[0];
  array& gw = outputs[1];
  array& gb = outputs[2];

  // Allocate space for the outputs
  bool x_in_gx = false;
  bool g_in_gx = false;
  if (x.is_donatable()) {
    gx.move_shared_buffer(x);
    x_in_gx = true;
  } else if (g.is_donatable()) {
    gx.move_shared_buffer(g);
    g_in_gx = true;
  } else {
    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
  }

  auto axis_size = static_cast<uint32_t>(x.shape().back());
  int n_rows = x.data_size() / axis_size;

  // Allocate a temporary to store the gradients for w and initialize the
  // gradient accumulator to 0.
  array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
  bool g_in_gw = false;
  if (!g_in_gx && g.is_donatable()) {
    gw_temp.move_shared_buffer(g);
    g_in_gw = true;
  } else {
    gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
  }
  copies.push_back(gw_temp);
  {
    array zero(0, gw.dtype());
    copy_gpu(zero, gw, CopyType::Scalar, s);
    copy_gpu(zero, gb, CopyType::Scalar, s);
    copies.push_back(std::move(zero));
  }

  // Finish with the gradient for b in case we had a b
  auto compute_encoder = d.get_command_encoder(s.index);
  if (gb.ndim() == 1 && gb.size() == axis_size) {
    ReductionPlan plan(
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
    strided_reduce_general_dispatch(
        g_in_gx ? gx : (g_in_gw ? gw_temp : g),
        gb,
        "sum",
        plan,
        {0},
        compute_encoder,
        d,
        s);
  }
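  // Note: gb is simply the cotangent summed over rows, so it reuses the same
  // strided sum-reduce that later accumulates gw_temp into gw.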

  const int simd_size = 32;
  const int n_reads = RMS_N_READS;
  const int looped_limit = RMS_LOOPED_LIMIT;
  std::string op_name = "vjp_layer_norm";
  if (axis_size > looped_limit) {
    op_name += "_looped";
  }
  op_name += type_to_name(gx);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, x_in_gx ? gx : x, 0);
    set_array_buffer(compute_encoder, w, 1);
    set_array_buffer(
        compute_encoder, g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
    set_array_buffer(compute_encoder, gx, 3);
    set_array_buffer(compute_encoder, gw_temp, 4);
    compute_encoder->setBytes(&eps_, sizeof(float), 5);
    compute_encoder->setBytes(&axis_size, sizeof(int), 6);
    compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 7);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }

  if (gw.ndim() == 1 && gw.size() == axis_size) {
    ReductionPlan plan(
        ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
    strided_reduce_general_dispatch(
        gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
  }

  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core::fast
@@ -17,7 +17,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
|
||||
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
@@ -822,7 +822,7 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (not is_integral(in.dtype())) {
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op(inputs, out, "round");
|
||||
} else {
|
||||
// No-op integer types
|
||||
@@ -865,7 +865,73 @@ void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
}

void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
  assert(inputs.size() == 1);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];

  // Calculate out strides, initial offset and if copy needs to be made
  auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);

  // Do copy if needed
  if (copy_needed) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
    copy_gpu_inplace(
        /* const array& in = */ in,
        /* array& out = */ out,
        /* const std::vector<int>& data_shape = */ out.shape(),
        /* const std::vector<stride_t>& i_strides = */ inp_strides,
        /* const std::vector<stride_t>& o_strides = */ ostrides,
        /* int64_t i_offset = */ data_offset,
        /* int64_t o_offset = */ 0,
        /* CopyType ctype = */ CopyType::General,
        /* const Stream& s = */ stream());
  } else {
    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
    shared_buffer_slice(in, ostrides, data_offset, out);
  }
}

void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }

  auto& in = inputs[0];
  auto& upd = inputs[1];

  if (upd.size() == 0) {
    out.copy_shared_buffer(in);
    return;
  }

  // Check if materialization is needed
  auto ctype = in.flags().contiguous && in.size() == in.data_size()
      ? CopyType::Vector
      : CopyType::General;
  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

  // Calculate out strides, initial offset and if copy needs to be made
  auto [data_offset, out_strides] = prepare_slice(out);

  // Do copy
  std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
  copy_gpu_inplace<int64_t>(
      /* const array& src = */ upd,
      /* array& dst = */ out,
      /* const std::vector<int>& data_shape = */ upd.shape(),
      /* const std::vector<stride_t>& i_strides = */ upd_strides,
      /* const std::vector<stride_t>& o_strides = */ out_strides,
      /* int64_t i_offset = */ 0,
      /* int64_t o_offset = */ data_offset,
      /* CopyType ctype = */ CopyType::GeneralGeneral,
      /* const Stream& s = */ stream());
}
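For orientation, a minimal self-contained sketch (illustrative only, not MLX source) of the arithmetic a helper like prepare_slice() performs above: the element offset where a strided view begins and the strides of that view, for a hypothetical row-major input and start/step values.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_strides{12, 4, 1}; // row-major 2x3x4 input (assumed)
  std::vector<int64_t> start{1, 0, 2};       // hypothetical slice start
  std::vector<int64_t> step{1, 2, 1};        // hypothetical slice step

  int64_t offset = 0;
  std::vector<int64_t> out_strides(3);
  for (size_t i = 0; i < 3; ++i) {
    offset += start[i] * in_strides[i];       // where the view begins
    out_strides[i] = step[i] * in_strides[i]; // stride scales with the step
  }
  std::cout << "offset = " << offset << "\n"; // 14
  std::cout << "strides = " << out_strides[0] << " " << out_strides[1] << " "
            << out_strides[2] << "\n";        // 12 8 1
}

When no copy is needed, the output can simply share the input buffer with this offset and these strides; otherwise the strided region is materialized with copy_gpu_inplace as above.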
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -900,4 +966,8 @@ void SVD::eval_gpu(
  throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
}

void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
  throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
}

} // namespace mlx::core
@@ -137,7 +137,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  int bo = std::min(32, O);
  int bo = 8;
  int bd = 32;
  MTL::Size group_dims = MTL::Size(bd, bo, 1);
  MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
@@ -4,10 +4,10 @@
#include <cassert>
#include <sstream>

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -18,8 +18,6 @@ namespace mlx::core {
// Case wise reduce dispatch
//////////////////////////////////////////////////////////////////////

namespace {

inline auto safe_div(size_t n, size_t m) {
  return m == 0 ? 0 : (n + m - 1) / m;
}
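A tiny standalone illustration (not part of the MLX sources) of the ceil-division helper above, which is the usual way to compute how many fixed-size blocks cover n elements:

#include <cassert>
#include <cstddef>

// Same ceil-division idea as safe_div above, shown in isolation.
inline size_t safe_div(size_t n, size_t m) {
  return m == 0 ? 0 : (n + m - 1) / m;
}

int main() {
  assert(safe_div(10, 4) == 3); // 10 elements in blocks of 4 -> 3 blocks
  assert(safe_div(8, 4) == 2);  // exact multiples are unchanged
  assert(safe_div(5, 0) == 0);  // divide-by-zero is guarded
}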
@@ -534,8 +532,6 @@ void strided_reduce_general_dispatch(
  }
}

} // namespace

//////////////////////////////////////////////////////////////////////
// Main reduce dispatch
//////////////////////////////////////////////////////////////////////
39  mlx/backend/metal/reduce.h  Normal file
@@ -0,0 +1,39 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/stream.h"

namespace mlx::core {

void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d,
    const Stream& s);

void row_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d,
    const Stream& s);

void strided_reduce_general_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const ReductionPlan& plan,
    const std::vector<int>& axes,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d,
    const Stream& s);

} // namespace mlx::core
@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

@@ -13,39 +13,74 @@ void RoPE::eval_gpu(
  auto& in = inputs[0];
  auto& out = outputs[0];

  if (in.ndim() != 3) {
    throw std::runtime_error(
        "[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
  }
  if (dims_ != in.shape(-1)) {
    throw std::runtime_error("[RoPE] Partial RoPE application not supported");
  }
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  if (in.ndim() < 3) {
    throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  size_t strides[3];
  size_t out_strides[3];
  bool donated = false;
  int ndim = in.ndim();
  size_t mat_size = in.shape(-2) * in.shape(-1);
  if (dims_ < in.shape(-1)) {
    donated = true;
    auto ctype =
        (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
    copy_gpu(in, out, ctype, s);
    strides[0] = mat_size;
    strides[1] = out.strides()[ndim - 2];
    strides[2] = out.strides()[ndim - 1];
  } else if (in.flags().row_contiguous) {
    if (in.is_donatable()) {
      donated = true;
      out.move_shared_buffer(in);
    } else {
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
    }
    strides[0] = mat_size;
    strides[1] = in.strides()[ndim - 2];
    strides[2] = in.strides()[ndim - 1];
  } else if (ndim == 3) {
    // Handle non-contiguous 3D inputs
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    strides[0] = in.strides()[0];
    strides[1] = in.strides()[1];
    strides[2] = in.strides()[2];
  } else {
    // Copy non-contiguous > 3D inputs into the output and treat
    // input as donated
    donated = true;
    copy_gpu(in, out, CopyType::General, s);
    strides[0] = mat_size;
    strides[1] = out.strides()[ndim - 2];
    strides[2] = out.strides()[ndim - 1];
  }
  out_strides[0] = mat_size;
  out_strides[1] = out.strides()[ndim - 2];
  out_strides[2] = out.strides()[ndim - 1];

  std::ostringstream kname;
  kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
  kname << "rope_" << (forward_ ? "" : "vjp_")
        << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);

  bool donated = in.data_shared_ptr() == nullptr;
  float base = std::log2(base_);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, donated ? out : in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&offset_, sizeof(int), 3);
  compute_encoder->setBytes(&base, sizeof(float), 4);
  compute_encoder->setBytes(&scale_, sizeof(float), 5);
  compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
  compute_encoder->setBytes(&offset_, sizeof(int), 4);
  compute_encoder->setBytes(&base, sizeof(float), 5);
  compute_encoder->setBytes(&scale_, sizeof(float), 6);

  int dim0 = in.shape(2) / 2;
  int dim1 = in.shape(1);
  int dim2 = in.shape(0);
  int dim0 = dims_ / 2;
  int dim1 = in.shape(-2);
  int dim2 = in.size() / mat_size;
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
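A scalar sketch (illustrative only) of the per-pair rotation the RoPE kernel applies above in the forward direction: out1 = x1*cos(theta) - x2*sin(theta), out2 = x1*sin(theta) + x2*cos(theta); the vjp direction uses the inverse rotation.

#include <cmath>
#include <cstdio>

int main() {
  float x1 = 1.0f, x2 = 0.0f;
  float theta = 1.57079632679f; // position * inv_freq for one dim (assumed)
  float o1 = x1 * std::cos(theta) - x2 * std::sin(theta);
  float o2 = x1 * std::sin(theta) + x2 * std::cos(theta);
  std::printf("(%f, %f)\n", o1, o2); // rotates (1, 0) -> (0, 1)
}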
@@ -127,7 +127,7 @@ void ScaledDotProductAttention::eval_gpu(
    const std::vector<array>& inputs,
    array& out) {
  assert(inputs.size() >= 3);
  if (!is_floating_point(out.dtype())) {
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[ScaledDotProductAttention] Does not yet support non-floating point types.");
  }
@@ -12,7 +12,7 @@ namespace mlx::core {

void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (!is_floating_point(out.dtype())) {
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[softmax] Does not support non-floating point types.");
  }
@@ -21,7 +21,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) {
  auto check_input = [&copies, &s](const array& x) -> const array& {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
@@ -30,18 +30,21 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      copies.push_back(x_copy);
      return x_copy;
      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
      copy_gpu(x, copies.back(), CopyType::General, s);
      return copies.back();
    }
  };
  const array& in = check_input(inputs[0]);
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());
  if (in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * in.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  }

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;
@@ -53,6 +56,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (axis_size > looped_limit) {
    op_name += "looped_";
  }
  if (in.dtype() != float32 && precise_) {
    op_name += "precise_";
  }
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
@@ -75,11 +81,10 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
    }

    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(
        compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
    set_array_buffer(compute_encoder, out, 1);
    compute_encoder->setBytes(&axis_size, sizeof(int), 2);
    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
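A minimal single-row softmax sketch (illustrative only, not the Metal kernel) matching the contract above: one contiguous row of axis_size elements, with the max subtracted for numerical stability.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> softmax_row(const std::vector<float>& x) {
  float mx = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - mx); // max-subtracted exponent
    sum += out[i];
  }
  for (auto& v : out) {
    v /= sum; // normalize so the row sums to 1
  }
  return out;
}

int main() {
  auto y = softmax_row({1.f, 2.f, 3.f});
  std::printf("%f %f %f\n", y[0], y[1], y[2]); // ~0.09 0.24 0.67
}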
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>

@@ -102,6 +102,11 @@ void multi_block_sort(

  int nc_dim = nc_shape.size();

  if (nc_dim == 0) {
    nc_shape = {0};
    nc_str = {1};
  }

  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];

@@ -143,8 +148,9 @@ void multi_block_sort(
  compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
  compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
  compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
  compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
  compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
  compute_encoder->setBytes(
      nc_shape.data(), nc_shape.size() * sizeof(int), 6);
  compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);

  MTL::Size group_dims = MTL::Size(bn, 1, 1);
  MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
@@ -158,7 +164,8 @@ void multi_block_sort(
  array dev_idxs_in = dev_idxs_0;
  array dev_vals_out = dev_vals_1;
  array dev_idxs_out = dev_idxs_1;
  for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {

  for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
    dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
    dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
    dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
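A small standalone sketch (illustrative only) of why the loop bound above changed: with a block count that is not a power of two, the old condition (merge_tiles <= n_blocks) could stop before a single fully merged run is produced, while the new one keeps doubling until the tile width covers all blocks.

#include <cstdio>

int main() {
  int n_blocks = 5; // hypothetical block count, not a power of two
  for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
    std::printf("merge pass, tile width %d\n", merge_tiles);
  }
  // prints passes for widths 2, 4, 8: the final pass (8 >= 5) merges all
  // blocks; the old bound would have stopped after the width-4 pass
}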
@@ -4,21 +4,49 @@

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

void set_array_buffer(
    MTL::ComputeCommandEncoder* enc,
    const array& a,
    int idx) {
inline void
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
}

inline void set_array_buffer(
    MTL::ComputeCommandEncoder* enc,
    const array& a,
    int64_t offset,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto base_offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  base_offset += offset;
  enc->setBuffer(a_buf, base_offset, idx);
}

template <typename T>
inline void set_vector_bytes(
    MTL::ComputeCommandEncoder* enc,
    const std::vector<T>& vec,
    size_t nelems,
    int idx) {
  enc->setBytes(vec.data(), nelems * sizeof(T), idx);
}

template <typename T>
inline void set_vector_bytes(
    MTL::ComputeCommandEncoder* enc,
    const std::vector<T>& vec,
    int idx) {
  return set_vector_bytes(enc, vec, vec.size(), idx);
}

std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
@@ -96,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}

inline NS::String* make_string(std::ostringstream& os) {
  std::string string = os.str();
  return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
}

inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
#ifdef MLX_METAL_DEBUG
  std::ostringstream label;
  label << "Stream " << index;
  queue->setLabel(make_string(label));
#endif
}

inline void debug_set_primitive_buffer_label(
    MTL::CommandBuffer* command_buffer,
    Primitive& primitive) {
#ifdef MLX_METAL_DEBUG
  std::ostringstream label;
  primitive.print(label);
  command_buffer->setLabel(make_string(label));
#endif
}

} // namespace

} // namespace mlx::core
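A self-contained sketch of how the two set_vector_bytes overloads above compose, with a stub standing in for MTL::ComputeCommandEncoder (the stub and its signature are assumptions for illustration, not the Metal API):

#include <cstddef>
#include <cstdio>
#include <vector>

struct StubEncoder {
  void setBytes(const void* /*ptr*/, size_t len, int idx) {
    std::printf("buffer %d: %zu bytes\n", idx, len);
  }
};

template <typename T>
void set_vector_bytes(
    StubEncoder* enc, const std::vector<T>& vec, size_t nelems, int idx) {
  enc->setBytes(vec.data(), nelems * sizeof(T), idx); // explicit count
}

template <typename T>
void set_vector_bytes(StubEncoder* enc, const std::vector<T>& vec, int idx) {
  set_vector_bytes(enc, vec, vec.size(), idx); // whole vector
}

int main() {
  StubEncoder enc;
  std::vector<int> shape{2, 3, 4};
  set_vector_bytes(&enc, shape, 6);    // buffer 6: 12 bytes
  set_vector_bytes(&enc, shape, 2, 6); // buffer 6: 8 bytes
}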
@@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
size_t set_cache_limit(size_t) {
  return 0;
}
bool start_capture(std::string path) {
  return false;
}
void stop_capture() {}

} // namespace mlx::core::metal
@@ -87,6 +87,7 @@ NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
@@ -98,8 +99,13 @@ NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)

namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(LayerNormVJP)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
} // namespace fast
107  mlx/compile.cpp
@@ -162,44 +162,51 @@ CompileMode& compile_mode() {
  return compile_mode_;
}

using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
    std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;

// Helper like below but only merges the two provided arrays. If the src has
// siblings then these won't be merged to the dst.
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
  auto src_parents = parents_map.find(src.id());
  if (src_parents == parents_map.end()) {
    return;
  }
  auto& pairs = parents_map[dst.id()];
  for (auto& parent : src_parents->second) {
    parent.first.inputs()[parent.second] = dst;
    pairs.push_back(parent);
  }
  // Remove the source from the map to avoid fusing with it again
  parents_map.erase(src_parents);
};

// Helper that merges two arrays in the graph by setting the parents of the
// source to point to the destination
// source to point to the destination. The arrays are assumed to be coming from
// equivalent primitives so their siblings are merged as well.
void merge(array& dst, array& src, ParentsMap& parents_map) {
  // Canonicalize the order of the primitives outputs
  auto sources = src.outputs();
  auto dests = dst.outputs();
  // For each src parent, point it to the corresponding dst
  for (int i = 0; i < sources.size(); ++i) {
    auto src_parents = parents_map.find(sources[i].id());
    if (src_parents == parents_map.end()) {
      continue;
    }
    auto& pairs = parents_map[dests[i].id()];
    for (auto& parent : src_parents->second) {
      parent.first.inputs()[parent.second] = dests[i];
      pairs.push_back(parent);
    }
    // Remove the source from the map to avoid fusing with it again
    parents_map.erase(src_parents);
    merge_one(dests[i], sources[i], parents_map);
  }
};

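A minimal standalone sketch (with plain ints standing in for arrays, so illustrative only) of the parent-rewiring idea in merge_one above: every parent edge of src is repointed to dst and src's entry is erased so it is never fused again.

#include <cstdio>
#include <unordered_map>
#include <utility>
#include <vector>

using ParentsMap = std::unordered_map<int, std::vector<std::pair<int, int>>>;

void merge_one(int dst, int src, ParentsMap& parents) {
  auto it = parents.find(src);
  if (it == parents.end()) {
    return;
  }
  auto moved = std::move(it->second);
  parents.erase(it); // never fuse with src again
  auto& pairs = parents[dst];
  pairs.insert(pairs.end(), moved.begin(), moved.end()); // repoint parents
}

int main() {
  ParentsMap parents{{7, {{42, 0}, {43, 1}}}}; // node 7 has two parents
  merge_one(3, 7, parents);
  std::printf("node 3 now has %zu parents\n", parents[3].size()); // 2
}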
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
  typedef T(fnType)(U...);
  fnType** fnPointer = f.template target<fnType*>();
  if (fnPointer == nullptr) {
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
  using FunType = T (*)(U...);
  const FunType* fun_ptr = fun.template target<FunType>();
  if (fun_ptr == nullptr) {
    throw std::invalid_argument(
        "[compile] Cannot compile a non-addressable function.");
  }
  return (size_t)*fnPointer;
  return reinterpret_cast<std::uintptr_t>(*fun_ptr);
}

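A standalone sketch of the std::function::target mechanism used above: target only yields a pointer when the wrapped callable is a plain function, which is why lambdas (stored by their closure type) are rejected as non-addressable.

#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>

int add_one(int x) { return x + 1; }

template <typename T, typename... U>
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
  using FunType = T (*)(U...);
  const FunType* fun_ptr = fun.template target<FunType>();
  if (fun_ptr == nullptr) {
    throw std::invalid_argument("non-addressable function");
  }
  return reinterpret_cast<std::uintptr_t>(*fun_ptr);
}

int main() {
  std::function<int(int)> f = add_one;
  std::cout << std::hex << get_function_address(f) << "\n";
  std::function<int(int)> g = [](int x) { return x; };
  try {
    get_function_address(g); // closures are not stored as function pointers
  } catch (const std::invalid_argument&) {
    std::cout << "lambda has no function-pointer target\n";
  }
}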
struct CompilerCache {
class CompilerCache {
 public:
  struct CacheEntry {
    std::vector<array> inputs;
    std::vector<array> outputs;
@@ -211,20 +218,20 @@ struct CompilerCache {
  // Returns a reference to a CacheEntry which can be updated
  // by the caller to avoid copying large tapes / inputs / outputs
  CacheEntry& find(
      size_t fun_id,
      std::uintptr_t fun_id,
      const std::vector<array>& inputs,
      bool shapeless,
      const std::vector<uint64_t>& constants) {
    // Try to find the entry
    auto [entry_it, inserted] = cache_.insert({fun_id, {}});
    auto& entries = entry_it->second;
    auto is_match = [shapeless](
                        const std::vector<array>& in1,
                        const std::vector<array>& in2) {
    // Find the cache entries for |fun_id|.
    std::vector<CacheEntry>& entries = cache_[fun_id];
    // Compare if 2 arrays have same shape and dtype.
    auto has_same_shape_and_dtype = [shapeless](
                                        const std::vector<array>& in1,
                                        const std::vector<array>& in2) {
      if (in1.size() != in2.size()) {
        return false;
      }
      for (int i = 0; i < in1.size(); ++i) {
      for (size_t i = 0; i < in1.size(); ++i) {
        if (in1[i].ndim() != in2[i].ndim()) {
          return false;
        }
@@ -237,14 +244,14 @@ struct CompilerCache {
      }
      return true;
    };

    // Loop over entries and check inputs match i.e. shapes and types must be
    // equal. Note this could get really slow if one compiles the same
    // function with many different shapes. May want to store entries in a
    // more easily searchable structure.
    for (auto& entry : entries) {
    for (CacheEntry& entry : entries) {
      // Check the inputs match and return if so
      if (is_match(inputs, entry.inputs) && constants == entry.constants) {
      if (has_same_shape_and_dtype(inputs, entry.inputs) &&
          constants == entry.constants) {
        return entry;
      }
    }
@@ -253,7 +260,7 @@ struct CompilerCache {
    return entries.back();
  };

  void erase(size_t fun_id) {
  void erase(std::uintptr_t fun_id) {
    cache_.erase(fun_id);
  }

@@ -263,8 +270,9 @@ struct CompilerCache {
    // initialized before the compiler cache
    allocator::allocator();
  }

  friend CompilerCache& compiler_cache();
  std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
  std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
};

CompilerCache& compiler_cache() {
@@ -523,9 +531,14 @@ void compile_fuse(
  // - Collect inputs to the new compiled primitive
  // - Add fusable primitives to a tape in the correct order

  std::function<void(const array&, int, const Stream&)> recurse;
  std::function<void(
      const array&, int, const Stream&, const std::vector<int>&)>
      recurse;
  std::unordered_set<uintptr_t> cache;
  recurse = [&](const array& a, int depth, const Stream& s) {
  recurse = [&](const array& a,
                int depth,
                const Stream& s,
                const std::vector<int>& shape) {
    if (cache.find(a.id()) != cache.end()) {
      return;
    }
@@ -535,8 +548,10 @@ void compile_fuse(
    // - Constant input
    // - Stream mismatch
    // - Non fusable primitive
    // - Is global output but has a different shape
    if (depth >= max_compile_depth || !a.has_primitive() ||
        a.primitive().stream() != s || !is_fusable(a.primitive())) {
        a.primitive().stream() != s || !is_fusable(a.primitive()) ||
        (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
      return;
    }

@@ -563,13 +578,13 @@ void compile_fuse(
    cache.insert({a.id()});

    for (auto& in : a.inputs()) {
      recurse(in, depth + 1, s);
      recurse(in, depth + 1, s, shape);
    }
  };

  if (arr.has_primitive()) {
    Stream s = arr.primitive().stream();
    recurse(arr, 0, s);
    recurse(arr, 0, s, arr.shape());
  }

  // Not worth fusing a single primitive
@@ -633,6 +648,10 @@ void compile_fuse(
  std::vector<std::vector<int>> shapes;
  std::vector<Dtype> types;
  for (auto& o : old_outputs) {
    if (o.shape() != old_outputs.back().shape()) {
      throw std::runtime_error(
          "[compile] Compilation failed. Tried to fuse operations with different output shapes");
    }
    shapes.push_back(o.shape());
    types.push_back(o.dtype());
  }
@@ -645,7 +664,7 @@ void compile_fuse(
    }
  }
  auto compiled_outputs = array::make_arrays(
      shapes,
      std::move(shapes),
      types,
      std::make_shared<Compiled>(
          old_outputs.back().primitive().stream(),
@@ -675,7 +694,7 @@ void compile_fuse(
  // - Update outputs parents to point to compiled outputs
  // - Update any overall graph outputs to be compiled outputs
  for (int o = 0; o < old_outputs.size(); ++o) {
    merge(compiled_outputs[o], old_outputs[o], parents_map);
    merge_one(compiled_outputs[o], old_outputs[o], parents_map);
    if (auto it = output_map.find(old_outputs[o].id());
        it != output_map.end()) {
      it->second = compiled_outputs[o];
@@ -738,8 +757,8 @@ std::vector<array> compile_replace(
      shapes.push_back(o.shape());
    }
  }
  auto real_out =
      array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
  auto real_out = array::make_arrays(
      std::move(shapes), types, a.primitive_ptr(), real_inputs);
  for (int i = 0; i < trace_out.size(); ++i) {
    trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
  }
@@ -774,7 +793,7 @@ void compile_validate_shapeless(const std::vector<array>& tape) {

std::function<std::vector<array>(const std::vector<array>&)> compile(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    size_t fun_id,
    std::uintptr_t fun_id,
    bool shapeless /* = false */,
    std::vector<uint64_t> constants /* = {} */) {
  if (compile_mode() == CompileMode::disabled ||
@@ -833,7 +852,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
  };
}

void compile_erase(size_t fun_id) {
void compile_erase(std::uintptr_t fun_id) {
  detail::compiler_cache().erase(fun_id);
}

@@ -845,7 +864,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
  if (detail::compile_mode() == CompileMode::disabled) {
    return fun;
  }
  auto fun_id = detail::getAddress(fun);
  auto fun_id = detail::get_function_address(fun);
  return detail::compile(fun, fun_id, shapeless);
}
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cstdint>
#include <sstream>
@@ -11,9 +11,10 @@ namespace mlx::core {

namespace {

static constexpr int num_types = 13;
constexpr int num_types = 13;
constexpr int num_cats = 8;

static constexpr Dtype::Kind type_kinds[num_types] = {
constexpr Dtype::Kind type_kinds[num_types] = {
    Dtype::Kind::b, // bool_,
    Dtype::Kind::u, // uint8,
    Dtype::Kind::u, // uint16,
@@ -32,7 +33,7 @@ static constexpr Dtype::Kind type_kinds[num_types] = {
// Following Jax type promotion rules:
// https://jax.readthedocs.io/en/latest/type_promotion.html
// clang-format off
static constexpr Dtype type_rules[num_types][num_types] = {
constexpr Dtype type_rules[num_types][num_types] = {
// bool  uint8  uint16  uint32  uint64  int8  int16  int32  int64  float16  float32  bfloat16  complex64
    {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool
    {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8
@@ -49,18 +50,37 @@ static constexpr Dtype type_rules[num_types][num_types] = {
    {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
};

constexpr bool subcategory_to_category[num_cats][num_cats] = {
// complexfloating floating inexact signedinteger unsignedinteger integer number generic
    {true, false, true, false, false, false, true, true}, // complexfloating
    {false, true, true, false, false, false, true, true}, // floating
    {false, false, true, false, false, false, true, true}, // inexact
    {false, false, false, true, false, true, true, true}, // signedinteger
    {false, false, false, false, true, true, true, true}, // unsignedinteger
    {false, false, false, false, false, true, true, true}, // integer
    {false, false, false, false, false, false, true, true}, // number
    {false, false, false, false, false, false, false, true}, // generic
};

constexpr Dtype::Category type_to_category[num_types] = {
    Dtype::Category::generic, // bool_,
    Dtype::Category::unsignedinteger, // uint8,
    Dtype::Category::unsignedinteger, // uint16,
    Dtype::Category::unsignedinteger, // uint32,
    Dtype::Category::unsignedinteger, // uint64,
    Dtype::Category::signedinteger, // int8,
    Dtype::Category::signedinteger, // int16,
    Dtype::Category::signedinteger, // int32,
    Dtype::Category::signedinteger, // int64,
    Dtype::Category::floating, // float16,
    Dtype::Category::floating, // float32,
    Dtype::Category::floating, // bfloat16,
    Dtype::Category::complexfloating, // complex64,
};

// clang-format on

inline bool is_big_endian() {
  union ByteOrder {
    int32_t i;
    uint8_t c[4];
  };
  ByteOrder b = {0x01234567};

  return b.c[0] == 0x01;
}

} // namespace

Dtype promote_types(const Dtype& t1, const Dtype& t2) {
@@ -141,6 +161,23 @@ TypeToDtype<complex64_t>::operator Dtype() {
  return complex64;
}

bool issubdtype(const Dtype& a, const Dtype& b) {
  return a == b;
}

bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
  return false;
}

bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
  return issubdtype(type_to_category[static_cast<uint32_t>(type.val)], cat);
}

bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
  return subcategory_to_category[static_cast<uint32_t>(a)]
                                [static_cast<uint32_t>(b)];
}

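A self-contained mini version (illustrative only) of the category check above, using the same lattice encoded in subcategory_to_category, where for instance floating is a subcategory of inexact, which is a subcategory of number and generic:

#include <cassert>

enum class Category { complexfloating, floating, inexact, signedinteger,
                      unsignedinteger, integer, number, generic };

// Same truth table as subcategory_to_category above.
constexpr bool sub[8][8] = {
    {true, false, true, false, false, false, true, true},
    {false, true, true, false, false, false, true, true},
    {false, false, true, false, false, false, true, true},
    {false, false, false, true, false, true, true, true},
    {false, false, false, false, true, true, true, true},
    {false, false, false, false, false, true, true, true},
    {false, false, false, false, false, false, true, true},
    {false, false, false, false, false, false, false, true},
};

bool issubdtype(Category a, Category b) {
  return sub[static_cast<int>(a)][static_cast<int>(b)];
}

int main() {
  assert(issubdtype(Category::floating, Category::inexact));
  assert(issubdtype(Category::signedinteger, Category::number));
  assert(!issubdtype(Category::integer, Category::floating));
}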
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
  std::ostringstream r;
@@ -153,9 +190,9 @@ std::string dtype_to_array_protocol(const Dtype& t) {
}

// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t) {
Dtype dtype_from_array_protocol(std::string_view t) {
  if (t.length() == 2 || t.length() == 3) {
    std::string r = t.length() == 3 ? t.substr(1, 2) : t;
    std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;

    if (r == "V2") {
      return bfloat16;
@@ -201,7 +238,7 @@ Dtype dtype_from_array_protocol(const std::string& t) {
  }

  throw std::invalid_argument(
      "[from_str] Invalid array protocol type-string: " + t);
      "[from_str] Invalid array protocol type-string: " + std::string(t));
}

} // namespace mlx::core
74  mlx/dtype.h
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -38,6 +38,17 @@ struct Dtype {
    V, /* void - used for brain float */
  };

  enum class Category {
    complexfloating,
    floating,
    inexact,
    signedinteger,
    unsignedinteger,
    integer,
    number,
    generic
  };

  Val val;
  const uint8_t size;
  constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
@@ -46,22 +57,38 @@ struct Dtype {
  };
};

static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};

static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
inline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};

static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};

static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

inline constexpr Dtype::Category complexfloating =
    Dtype::Category::complexfloating;
inline constexpr Dtype::Category floating = Dtype::Category::floating;
inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
inline constexpr Dtype::Category unsignedinteger =
    Dtype::Category::unsignedinteger;
inline constexpr Dtype::Category integer = Dtype::Category::integer;
inline constexpr Dtype::Category number = Dtype::Category::number;
inline constexpr Dtype::Category generic = Dtype::Category::generic;

bool issubdtype(const Dtype& a, const Dtype& b);
bool issubdtype(const Dtype::Category& a, const Dtype& b);
bool issubdtype(const Dtype& a, const Dtype::Category& b);
bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);

Dtype promote_types(const Dtype& t1, const Dtype& t2);

@@ -71,23 +98,6 @@ inline uint8_t size_of(const Dtype& t) {

Dtype::Kind kindof(const Dtype& t);

inline bool is_unsigned(const Dtype& t) {
  return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
}

inline bool is_floating_point(const Dtype& t) {
  return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
      kindof(t) == Dtype::Kind::c;
}

inline bool is_complex(const Dtype& t) {
  return kindof(t) == Dtype::Kind::c;
}

inline bool is_integral(const Dtype& t) {
  return !(is_floating_point(t));
}

template <typename T>
struct TypeToDtype {
  operator Dtype();
@@ -96,6 +106,6 @@ struct TypeToDtype {
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t);
Dtype dtype_from_array_protocol(std::string_view t);

} // namespace mlx::core
397  mlx/fast.cpp
@@ -1,5 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <numeric>

#include "mlx/fast.h"
#include "mlx/fast_primitives.h"
#include "mlx/ops.h"
@@ -46,6 +49,278 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
  return {outputs, out_axes};
}

array rms_norm(
    const array& x,
    const array& weight,
    float eps,
    StreamOrDevice s_ /* = {} */) {
  if (x.ndim() == 0) {
    std::ostringstream msg;
    msg << "[rms_norm] Input must have at least 1 dimension but got input with "
           "0 dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (weight.ndim() != 1) {
    std::ostringstream msg;
    msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  auto out_type = result_type(x, weight);
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[rms_norm] Received unsupported type " << out_type << ".";
    throw std::invalid_argument(msg.str());
  }

  auto s = to_stream(s_);
  auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
    auto x = astype(inputs[0], float32, s);
    x = multiply(
        x,
        rsqrt(
            add(mean(square(x, s), -1, /* keepdims */ true, s),
                array(eps, float32),
                s),
            s),
        s);
    x = astype(x, out_type, s);
    return std::vector<array>{multiply(inputs[1], x, s)};
  };
  if (s.device == Device::gpu) {
    return array(
        x.shape(),
        out_type,
        std::make_shared<RMSNorm>(s, fallback, eps),
        {astype(x, out_type, s), astype(weight, out_type, s)});
  }
  return fallback({x, weight})[0];
}
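A scalar sketch (illustrative only) of the rms_norm fallback above: y = w * x / sqrt(mean(x^2) + eps), with the normalization computed in float32 as in the code.

#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> rms_norm(
    const std::vector<float>& x, const std::vector<float>& w, float eps) {
  float ms = 0.f;
  for (float v : x) {
    ms += v * v;
  }
  ms /= x.size();                       // mean of squares
  float inv = 1.f / std::sqrt(ms + eps);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = w[i] * x[i] * inv;         // scale by the learned weight
  }
  return out;
}

int main() {
  auto y = rms_norm({1.f, 2.f, 3.f}, {1.f, 1.f, 1.f}, 1e-5f);
  std::printf("%f %f %f\n", y[0], y[1], y[2]);
}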
std::vector<array> RMSNorm::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  assert(primals.size() == 2);
  assert(outputs.size() == 1);
  assert(cotangents.size() == 1);

  auto s = stream();
  auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto& w = inputs[1];
    auto& g = inputs[2];

    std::vector<array> vjps;

    auto n = rsqrt(
        add(mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s),
            array(eps, x.dtype()),
            s),
        s);
    auto n3 = power(n, array(3, x.dtype()), s);

    // df/dx
    auto gw = multiply(g, w, s);
    auto t = mean(multiply(gw, x, s), /* axis= */ -1, /* keepdims= */ true, s);
    t = multiply(multiply(x, t, s), n3, s);
    vjps.push_back(subtract(multiply(gw, n, s), t, s));

    // df/dw
    std::vector<int> axes(g.ndim() - 1);
    std::iota(axes.begin(), axes.end(), 0);
    vjps.push_back(
        sum(multiply(g, multiply(x, n, s), s), axes, /* keepdims= */ false, s));

    return vjps;
  };

  auto vjps = array::make_arrays(
      {primals[0].shape(), primals[1].shape()},
      {primals[0].dtype(), primals[1].dtype()},
      std::make_shared<RMSNormVJP>(s, fallback, eps_),
      {primals[0], primals[1], cotangents[0]});

  std::vector<array> returned_vjps;
  for (auto& arg : argnums) {
    returned_vjps.push_back(std::move(vjps[arg]));
  }

  return returned_vjps;
}

bool RMSNorm::is_equivalent(const Primitive& other) const {
  const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
  return eps_ == a_other.eps_;
}

bool RMSNormVJP::is_equivalent(const Primitive& other) const {
  const RMSNormVJP& a_other = static_cast<const RMSNormVJP&>(other);
  return eps_ == a_other.eps_;
}

array layer_norm(
    const array& x,
    const std::optional<array>& weight,
    const std::optional<array>& bias,
    float eps,
    StreamOrDevice s_ /* = {} */) {
  if (x.ndim() == 0) {
    std::ostringstream msg;
    msg << "[layer_norm] Input must have at least 1 dimension but got input with "
           "0 dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (weight.has_value() && (*weight).ndim() != 1) {
    std::ostringstream msg;
    msg << "[layer_norm] weight must have 1 dimension but has "
        << (*weight).ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (bias.has_value() && (*bias).ndim() != 1) {
    std::ostringstream msg;
    msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  auto out_type = (weight.has_value())
      ? ((bias.has_value()) ? result_type(x, *weight, *bias)
                            : result_type(x, *weight))
      : x.dtype();
  if (!issubdtype(out_type, floating)) {
    std::ostringstream msg;
    msg << "[layer_norm] Received unsupported type " << out_type << ".";
    throw std::invalid_argument(msg.str());
  }

  auto s = to_stream(s_);
  bool has_weight = weight.has_value();
  bool has_bias = bias.has_value();
  auto fallback = [has_weight, has_bias, eps, out_type, s](
                      const std::vector<array>& inputs) {
    auto x = astype(inputs[0], float32, s);

    // Should I not be smart here and leave the double mean to simplify()?
    auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
    auto mu2 = square(mu, s);
    auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
    auto v = subtract(x2, mu2, s);

    x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
    x = astype(x, out_type, s);

    // If the LN is affine then transform x according to the weight and bias
    if (has_weight) {
      x = multiply(x, inputs[1], s);
    }
    if (has_bias) {
      x = add(x, inputs[2], s);
    }

    return std::vector<array>{x};
  };

  auto passed_weight =
      astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
  auto passed_bias =
      astype((bias.has_value()) ? *bias : array(0, out_type), out_type);

  if (s.device == Device::gpu) {
    return array(
        x.shape(),
        out_type,
        std::make_shared<LayerNorm>(s, fallback, eps),
        {astype(x, out_type, s), passed_weight, passed_bias});
  }
  return fallback({x, passed_weight, passed_bias})[0];
}
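A scalar sketch (illustrative only) of the layer_norm fallback above: y = (x - mean(x)) / sqrt(var(x) + eps) * w + b, with the variance computed as E[x^2] - E[x]^2 exactly as in the code.

#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> layer_norm(
    const std::vector<float>& x,
    const std::vector<float>& w,
    const std::vector<float>& b,
    float eps) {
  float mu = 0.f, mu2 = 0.f;
  for (float v : x) {
    mu += v;
    mu2 += v * v;
  }
  mu /= x.size();
  mu2 /= x.size();
  float var = mu2 - mu * mu; // E[x^2] - E[x]^2, as in the fallback
  float inv = 1.f / std::sqrt(var + eps);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = (x[i] - mu) * inv * w[i] + b[i];
  }
  return out;
}

int main() {
  auto y = layer_norm({1.f, 2.f, 3.f}, {1.f, 1.f, 1.f}, {0.f, 0.f, 0.f}, 1e-5f);
  std::printf("%f %f %f\n", y[0], y[1], y[2]); // ~ -1.22 0.00 1.22
}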
std::vector<array> LayerNorm::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  assert(primals.size() == 3);
  assert(outputs.size() == 1);
  assert(cotangents.size() == 1);

  auto s = stream();
  auto fallback = [eps = eps_, s](const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto& w = inputs[1];
    auto& b = inputs[2];
    auto& g = inputs[3];

    std::vector<array> vjps;

    auto norm = number_of_elements(x, {-1}, true, x.dtype(), s);
    auto sumx = sum(x, /* axis= */ -1, /* keepdims= */ true, s);
    auto sumx2 = sum(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
    auto mu = multiply(sumx, norm, s);
    auto mu2 = multiply(sumx2, norm, s);
    auto var = subtract(mu2, square(mu, s), s);
    auto n = rsqrt(add(var, array(eps, x.dtype()), s));
    auto n3 = power(n, array(3, x.dtype()), s);
    auto x_c = subtract(x, mu, s);

    // df/dx
    auto wg = multiply(w, g, s);
    auto sumwg =
        multiply(sum(wg, /* axis= */ -1, /* keepdims= */ true, s), norm, s);
    auto sumwgxc = multiply(
        sum(multiply(wg, x_c, s), /* axis= */ -1, /* keepdims= */ true, s),
        norm,
        s);
    auto t1 = multiply(multiply(x_c, sumwgxc, s), n3, s);
    auto t2 = multiply(subtract(wg, sumwg, s), n, s);
    vjps.push_back(subtract(t2, t1, s));

    // df/dw
    std::vector<int> axes(g.ndim() - 1);
    std::iota(axes.begin(), axes.end(), 0);
    if (w.ndim() == 0) {
      vjps.push_back(zeros_like(w, s));
    } else {
      vjps.push_back(sum(
          multiply(g, multiply(x_c, n, s), s), axes, /* keepdims= */ false, s));
    }

    // df/db
    if (b.ndim() == 0) {
      vjps.push_back(zeros_like(w, s));
    } else {
      vjps.push_back(sum(g, axes, /* keepdims= */ false, s));
    }

    return vjps;
  };

  auto vjps = array::make_arrays(
      {primals[0].shape(), primals[1].shape(), primals[2].shape()},
      {primals[0].dtype(), primals[1].dtype(), primals[2].dtype()},
      std::make_shared<LayerNormVJP>(s, fallback, eps_),
      {primals[0], primals[1], primals[2], cotangents[0]});

  std::vector<array> returned_vjps;
  for (auto& arg : argnums) {
    returned_vjps.push_back(std::move(vjps[arg]));
  }

  return returned_vjps;
}

bool LayerNorm::is_equivalent(const Primitive& other) const {
  const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
  return eps_ == a_other.eps_;
}

bool LayerNormVJP::is_equivalent(const Primitive& other) const {
  const LayerNormVJP& a_other = static_cast<const LayerNormVJP&>(other);
  return eps_ == a_other.eps_;
}

array rope(
    const array& x,
    int dims,
@@ -53,21 +328,20 @@ array rope(
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 3) {
    bool forward,
    StreamOrDevice s) {
  if (x.ndim() < 3) {
    std::ostringstream msg;
    msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
        << " dimensions.";
    msg << "[rope] Input must have at least 3 dimensions but got input with "
        << x.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (traditional && x.shape(-1) != dims) {
    throw std::invalid_argument(
        "[rope] Does not support partial traditional application.");
  }

  auto fallback = [dims, traditional, base, scale, offset, s](
  auto fallback = [dims, traditional, base, scale, offset, forward, s](
                      const std::vector<array>& inputs) {
    auto& x = inputs[0];
    auto& shape = inputs[0].shape();
    int ndim = shape.size();
    auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
    auto t = x.dtype();
    auto N = x.shape(1) + offset;
    // Compute sines and cosines
@@ -80,16 +354,39 @@ array rope(
    auto coss = cos(theta, s);
    auto sins = sin(theta, s);

    if (traditional) {
      auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
      auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
    auto apply_rope = [forward, s](
                          const array& x1,
                          const array& x2,
                          const array& coss,
                          const array& sins) {
      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      if (forward) {
        outs.push_back(
            subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
        outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      } else {
        outs.push_back(add(multiply(x2, sins, s), multiply(x1, coss, s), s));
        outs.push_back(
            subtract(multiply(x2, coss, s), multiply(x1, sins, s), s));
      }
      return outs;
    };

    if (traditional) {
      auto x1 =
          slice(x, {0, 0, 0}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
      auto x2 =
          slice(x, {0, 0, 1}, {x.shape(0), x.shape(1), dims}, {1, 1, 2}, s);
      auto outs = apply_rope(x1, x2, coss, sins);
      for (auto& o : outs) {
        o = expand_dims(o, 3, s);
      }
      return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
      auto out = concatenate(outs, 3, s);
      if (dims < x.shape(-1)) {
        out = reshape(out, {x.shape(0), x.shape(1), dims});
        out = concatenate({out, slice(x, {0, 0, dims}, x.shape(), s)}, 2, s);
      }
      return std::vector<array>{reshape(out, shape, s)};
    } else {
      auto out_s = x.shape();
      out_s.back() = half_dims;
@@ -97,33 +394,67 @@ array rope(
      out_s.back() = dims;
      auto x2 = slice(x, {0, 0, half_dims}, out_s, s);

      std::vector<array> outs;
      outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
      outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
      auto outs = apply_rope(x1, x2, coss, sins);
      if (dims < x.shape(-1)) {
        outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
      }
      return std::vector<array>{concatenate(outs, 2, s)};
      return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
    }
  };
  auto stream = to_stream(s);
  if (stream.device == Device::gpu && x.shape(-1) == dims) {
  if (stream.device == Device::gpu) {
    return array(
        x.shape(),
        x.dtype(),
        std::make_unique<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset),
        std::make_shared<RoPE>(
            stream, fallback, dims, traditional, base, scale, offset, forward),
        {x});
  }
  return fallback({x})[0];
}

array rope(
    const array& x,
    int dims,
    bool traditional,
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */) {
  return rope(x, dims, traditional, base, scale, offset, true, s);
}
std::vector<array> RoPE::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  auto s = stream();
  auto fallback = [dims = dims_,
                   traditional = traditional_,
                   base = base_,
                   scale = scale_,
                   offset = offset_,
                   forward = forward_,
                   s](std::vector<array> inputs) {
    return std::vector<array>{
        rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
  };

  return {array(
      cotangents[0].shape(),
      cotangents[0].dtype(),
      std::make_shared<RoPE>(
          s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
      cotangents)};
}

bool RoPE::is_equivalent(const Primitive& other) const {
  const RoPE& a_other = static_cast<const RoPE&>(other);
  return (
      dims_ == a_other.dims_ && base_ == a_other.base_ &&
      scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
      offset_ == a_other.offset_);
      offset_ == a_other.offset_ && forward_ == a_other.forward_);
}

/** Computes: O = softmax(Q @ K.T) @ V **/
@@ -181,8 +512,8 @@ array scaled_dot_product_attention(
    throw std::invalid_argument(msg.str());
  }

  auto final_type = result_type({queries, keys, values});
  if (!is_floating_point(final_type) || is_complex(final_type)) {
  auto final_type = result_type(queries, keys, values);
  if (!issubdtype(final_type, floating)) {
    std::ostringstream msg;
    msg << "[scaled_dot_product_attention] Received unsupported type "
        << final_type << ".";
@@ -193,9 +524,6 @@ array scaled_dot_product_attention(
  auto k = astype(keys, final_type, s);
  auto v = astype(values, final_type, s);

  auto out_shape =
      std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});

  /* generic implementation for use cases that Metal implementation does not
   * support. For non-supported cases listed below, use MLX primitives:
   * * CPU implementation
@@ -222,10 +550,7 @@ array scaled_dot_product_attention(
    if (needs_mask) {
      scores = add(scores, inputs[3], s);
    }
    scores = astype(
        softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
        final_type,
        s);
    scores = softmax(scores, std::vector<int>{-1}, true, s);
    auto out = matmul(scores, v, s);
    if (n_repeats > 1) {
      out = reshape(out, {B, n_q_heads, L, -1}, s);
@@ -244,10 +569,12 @@ array scaled_dot_product_attention(
  // TODO, update routing conditions post further tuning
  implementation_supports_use_case &= false;
  if (implementation_supports_use_case) {
    auto out_shape =
        std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
    auto out = array(
        out_shape,
        std::move(out_shape),
        final_type,
        std::make_unique<ScaledDotProductAttention>(
        std::make_shared<ScaledDotProductAttention>(
            stream, fallback, scale, false),
        {q, k, v});
    return out;
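A tiny single-query attention sketch (illustrative only) of the fallback math above, O = softmax(scale * Q K^T) V, with one query over three keys and scalar values for brevity:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> q{1.f, 0.f};                              // head_dim = 2
  std::vector<std::vector<float>> k{{1.f, 0.f}, {0.f, 1.f}, {1.f, 1.f}};
  std::vector<float> v{1.f, 2.f, 3.f};                         // scalar values
  float scale = 1.f / std::sqrt(2.f);

  std::vector<float> scores(k.size());
  float mx = -1e30f, sum = 0.f;
  for (size_t i = 0; i < k.size(); ++i) {
    scores[i] = scale * (q[0] * k[i][0] + q[1] * k[i][1]);     // scaled Q.K
    mx = std::max(mx, scores[i]);
  }
  for (auto& sc : scores) {
    sc = std::exp(sc - mx); // stable softmax numerator
    sum += sc;
  }
  float out = 0.f;
  for (size_t i = 0; i < k.size(); ++i) {
    out += (scores[i] / sum) * v[i]; // weighted sum of values
  }
  std::printf("%f\n", out);
}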
15  mlx/fast.h
@@ -8,6 +8,19 @@

namespace mlx::core::fast {

array rms_norm(
    const array& x,
    const array& weight,
    float eps,
    StreamOrDevice s = {});

array layer_norm(
    const array& x,
    const std::optional<array>& weight,
    const std::optional<array>& bias,
    float eps,
    StreamOrDevice s = {});

array rope(
    const array& x,
    int dims,
@@ -15,7 +28,7 @@ array rope(
    float base,
    float scale,
    int offset,
    StreamOrDevice s /* = {} */);
    StreamOrDevice s = {});

/** Computes: O = softmax(Q @ K.T) @ V **/
array scaled_dot_product_attention(
@@ -1,3 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
@@ -31,6 +33,110 @@ class Custom : public Primitive {
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
};
|
||||
|
||||
class RMSNorm : public Custom {
|
||||
public:
|
||||
RMSNorm(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(RMSNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RMSNormVJP : public Custom {
|
||||
public:
|
||||
RMSNormVJP(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(RMSNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNorm : public Custom {
|
||||
public:
|
||||
LayerNorm(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(LayerNorm)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class LayerNormVJP : public Custom {
|
||||
public:
|
||||
LayerNormVJP(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
float eps)
|
||||
: Custom(stream, fallback), eps_(eps){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(LayerNormVJP)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
float eps_;
|
||||
};
|
||||
|
||||
class RoPE : public Custom {
|
||||
public:
|
||||
RoPE(
|
||||
@@ -40,19 +146,29 @@ class RoPE : public Custom {
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset)
|
||||
int offset,
|
||||
bool forward)
|
||||
: Custom(stream, fallback),
|
||||
dims_(dims),
|
||||
traditional_(traditional),
|
||||
base_(base),
|
||||
scale_(scale),
|
||||
offset_(offset){};
|
||||
offset_(offset),
|
||||
forward_(forward){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(RoPE)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
@@ -63,6 +179,7 @@ class RoPE : public Custom {
|
||||
float base_;
|
||||
float scale_;
|
||||
int offset_;
|
||||
bool forward_;
|
||||
};
|
||||
|
||||
class ScaledDotProductAttention : public Custom {
|
||||
@@ -76,7 +193,7 @@ class ScaledDotProductAttention : public Custom {
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
outputs[0] = fallback_(inputs)[0];
|
||||
throw std::runtime_error("NYI");
|
||||
};
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
|
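A note on the pattern above: each fast primitive is constructed with a fallback_ closure built from ordinary ops, so a backend without a fused kernel can still evaluate the op (as the old ScaledDotProductAttention eval_cpu did before this change). A minimal standalone sketch of the idea, with placeholder names rather than the MLX types:

    #include <functional>
    #include <stdexcept>
    #include <vector>

    // Sketch of the "custom primitive with fallback" pattern: the primitive
    // owns a closure composed of ordinary ops, and any backend that lacks a
    // fused kernel can evaluate the closure instead.
    struct Tensor {};  // placeholder for an array type

    class CustomOp {
     public:
      using Fallback =
          std::function<std::vector<Tensor>(const std::vector<Tensor>&)>;

      explicit CustomOp(Fallback fallback) : fallback_(std::move(fallback)) {}

      // Fast path: a fused kernel would go here.
      std::vector<Tensor> eval_gpu(const std::vector<Tensor>&) {
        throw std::runtime_error("fused kernel not implemented in this sketch");
      }

      // Slow path: defer to the composition of ordinary ops.
      std::vector<Tensor> eval_cpu(const std::vector<Tensor>& inputs) {
        return fallback_(inputs);
      }

     private:
      Fallback fallback_;
    };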
@@ -95,7 +95,7 @@ array fft_impl(
   return array(
       out_shape,
       out_type,
-      std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
+      std::make_shared<FFT>(to_stream(s), valid_axes, inverse, real),
       {astype(in, in_type, s)});
 }
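The make_unique to make_shared swap here recurs throughout this diff: an array holds a handle to its producing primitive, and once a primitive can be referenced by several output arrays, ownership must be shared rather than unique. A tiny illustration of the constraint, with placeholder types:

    #include <memory>
    #include <vector>

    // Why shared_ptr: when one primitive produces several outputs, every
    // output node must keep the same primitive alive.
    struct Primitive { int id = 0; };

    struct ArrayNode {
      std::shared_ptr<Primitive> prim;  // shared with sibling outputs
    };

    int main() {
      auto p = std::make_shared<Primitive>();
      std::vector<ArrayNode> outputs{{p}, {p}};  // two outputs, one primitive
      return outputs[0].prim == outputs[1].prim ? 0 : 1;
    }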
@@ -6,12 +6,10 @@
 #include "array.h"
-#include "device.h"
-#include "stream.h"
+#include "utils.h"
 
 namespace mlx::core::fft {
 
-using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
-
 /** Compute the n-dimensional Fourier Transform. */
 array fftn(
     const array& a,
@@ -23,8 +23,7 @@ const std::string& NodeNamer::get_name(const array& x) {
       letters.push_back('A' + (var_num - 1) % 26);
       var_num = (var_num - 1) / 26;
     }
-    std::string name(letters.rbegin(), letters.rend());
-    names.insert({x.id(), name});
+    names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
 
     return get_name(x);
 }
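The emplace keeps the spreadsheet-style naming intact. Standalone, the loop implements bijective base-26, so variable 1 is named A, 26 is Z, and 27 rolls over to AA:

    #include <algorithm>
    #include <cassert>
    #include <string>

    // Standalone rendering of the naming loop above: bijective base-26,
    // so 1 -> "A", 2 -> "B", ..., 26 -> "Z", 27 -> "AA", 28 -> "AB".
    std::string spreadsheet_name(int var_num) {
      std::string letters;
      while (var_num > 0) {
        letters.push_back('A' + (var_num - 1) % 26);
        var_num = (var_num - 1) / 26;
      }
      return std::string(letters.rbegin(), letters.rend());
    }

    int main() {
      assert(spreadsheet_name(1) == "A");
      assert(spreadsheet_name(26) == "Z");
      assert(spreadsheet_name(27) == "AA");
      return 0;
    }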
@@ -14,15 +14,15 @@ struct NodeNamer {
 
 void print_graph(std::ostream& os, const std::vector<array>& outputs);
 
-template <typename... Arrays>
-void print_graph(std::ostream& os, Arrays... outputs) {
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+void print_graph(std::ostream& os, Arrays&&... outputs) {
   print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
 }
 
 void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
 
-template <typename... Arrays>
-void export_to_dot(std::ostream& os, Arrays... outputs) {
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+void export_to_dot(std::ostream& os, Arrays&&... outputs) {
   export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
 }
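The new default template argument keeps these variadic overloads out of overload resolution unless every argument is an array, so they no longer shadow the std::vector overloads. The diff does not show the definition of enable_for_arrays_t; one plausible definition (an assumption, not quoted from the source) is:

    #include <type_traits>

    // One plausible definition of the SFINAE alias used above (an
    // assumption; the diff only shows its use): the overload participates
    // only when every decayed parameter type is exactly `array`.
    struct array {};  // stand-in for mlx::core::array

    template <typename... T>
    using enable_for_arrays_t = typename std::enable_if<
        (std::is_same<typename std::decay<T>::type, array>::value && ...)>::type;

    // Usage: this overload drops out of resolution for non-array arguments.
    template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
    void eval(Arrays&&... outputs) {}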
mlx/io.h (8 changes)
@@ -23,13 +23,13 @@ using SafetensorsLoad = std::pair<
 void save(std::shared_ptr<io::Writer> out_stream, array a);
 
 /** Save array to file in .npy format */
-void save(const std::string& file, array a);
+void save(std::string file, array a);
 
 /** Load array from reader in .npy format */
 array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
 
 /** Load array from file in .npy format */
-array load(const std::string& file, StreamOrDevice s = {});
+array load(std::string file, StreamOrDevice s = {});
 
 /** Load array map from .safetensors file format */
 SafetensorsLoad load_safetensors(
@@ -44,13 +44,13 @@ void save_safetensors(
     std::unordered_map<std::string, array>,
     std::unordered_map<std::string, std::string> metadata = {});
 void save_safetensors(
-    const std::string& file,
+    std::string file,
     std::unordered_map<std::string, array>,
     std::unordered_map<std::string, std::string> metadata = {});
 
 /** Load array map and metadata from .gguf file format */
 
-GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
+GGUFLoad load_gguf(std::string_view file, StreamOrDevice s = {});
 
 void save_gguf(
     std::string file,
@@ -206,7 +206,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   std::unordered_map<std::string, array> array_map;
   gguf_tensor tensor;
 
-  auto check_insert = [](auto inserted) {
+  auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
       std::ostringstream msg;
       msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -216,6 +216,7 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   };
 
   while (gguf_get_tensor(ctx, &tensor)) {
+    std::string name(tensor.name, tensor.namelen);
     if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
         tensor.type == GGUF_TYPE_Q8_0) {
       gguf_load_quantized(array_map, tensor);
@@ -224,14 +225,14 @@ std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
 
     const auto& [data, dtype] = extract_tensor_data(&tensor);
     array loaded_array = array(data, get_shape(tensor), dtype);
-    array_map.insert({name, loaded_array});
+    check_insert(array_map.insert({name, loaded_array}));
     }
   }
   return array_map;
 }
 
-GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
-  gguf_ctx* ctx = gguf_open(file.c_str());
+GGUFLoad load_gguf(std::string_view file, StreamOrDevice s) {
+  gguf_ctx* ctx = gguf_open(file.data());
   if (!ctx) {
     throw std::runtime_error("[load_gguf] gguf_init failed");
   }
@@ -105,7 +105,8 @@ void gguf_load_quantized(
     weights_per_byte = 1;
   }
 
-  std::string name = std::string(tensor.name, tensor.namelen);
+  std::string name(tensor.name, tensor.namelen);
+
   std::vector<int> shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
@@ -136,9 +137,9 @@ void gguf_load_quantized(
     extract_q8_0_data(tensor, weights, scales, biases);
   }
 
-  a.insert({name, weights});
+  a.emplace(name, std::move(weights));
 
-  auto check_insert = [](auto inserted) {
+  auto check_insert = [](const auto& inserted) {
     if (!inserted.second) {
       std::ostringstream msg;
       msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
@@ -147,11 +148,11 @@ void gguf_load_quantized(
     }
   };
 
-  const std::string weight_suffix = ".weight";
+  constexpr std::string_view weight_suffix = ".weight";
   const std::string name_prefix =
       name.substr(0, name.length() - weight_suffix.length());
-  check_insert(a.insert({name_prefix + ".scales", scales}));
-  check_insert(a.insert({name_prefix + ".biases", biases}));
+  check_insert(a.emplace(name_prefix + ".scales", std::move(scales)));
+  check_insert(a.emplace(name_prefix + ".biases", std::move(biases)));
 }
 
 } // namespace mlx::core
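Both gguf loaders now route insertions through check_insert, which relies on the pair<iterator, bool> returned by unordered_map insert/emplace; the bool is false for a duplicate key. A compact demonstration (the tensor name is chosen for illustration):

    #include <cassert>
    #include <string>
    #include <unordered_map>

    // std::unordered_map::emplace returns pair<iterator, bool>; the bool is
    // false when the key already existed, which is how the loader detects
    // duplicate tensor names.
    int main() {
      std::unordered_map<std::string, int> m;
      auto first = m.emplace("layers.0.weight", 1);
      auto dup = m.emplace("layers.0.weight", 2);
      assert(first.second);            // inserted
      assert(!dup.second);             // rejected: key already present
      assert(dup.first->second == 1);  // map still holds the original value
      return 0;
    }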
@@ -1,5 +1,4 @@
-// Copyright © 2023 Apple Inc.
-
+// Copyright © 2023-2024 Apple Inc.
 #include <algorithm>
 #include <cstring>
 #include <fstream>
@@ -18,7 +17,7 @@ namespace mlx::core {
 
 namespace {
 
-static constexpr uint8_t MAGIC[] = {
+constexpr uint8_t MAGIC[] = {
     0x93,
     0x4e,
     0x55,
@@ -27,16 +26,6 @@ static constexpr uint8_t MAGIC[] = {
     0x59,
 };
 
-inline bool is_big_endian_() {
-  union ByteOrder {
-    int32_t i;
-    uint8_t c[4];
-  };
-  ByteOrder b = {0x01234567};
-
-  return b.c[0] == 0x01;
-}
-
 } // namespace
 
 /** Save array to out stream in .npy format */
@@ -95,7 +84,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
     uint16_t v1_header_len = header.tellp();
     const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
 
-    if (!is_big_endian_()) {
+    if (!is_big_endian()) {
       magic_ver_len.write(len_bytes, 2);
     } else {
       magic_ver_len.write(len_bytes + 1, 1);
@@ -107,7 +96,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
     uint32_t v2_header_len = header.tellp();
     const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
 
-    if (!is_big_endian_()) {
+    if (!is_big_endian()) {
       magic_ver_len.write(len_bytes, 4);
     } else {
       magic_ver_len.write(len_bytes + 3, 1);
@@ -122,21 +111,16 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
   out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
   out_stream->write(header.str().c_str(), header.str().length());
   out_stream->write(a.data<char>(), a.nbytes());
-
-  return;
 }
 
 /** Save array to file in .npy format */
-void save(const std::string& file_, array a) {
-  // Open and check file
-  std::string file = file_;
-
+void save(std::string file, array a) {
   // Add .npy to file name if it is not there
   if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
     file += ".npy";
 
   // Serialize array
-  save(std::make_shared<io::FileWriter>(file), a);
+  save(std::make_shared<io::FileWriter>(std::move(file)), a);
 }
 
 /** Load array from reader in .npy format */
@@ -222,7 +206,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
   // Build primitive
 
   size_t offset = 8 + header_len_size + header.length();
-  bool swap_endianness = read_is_big_endian != is_big_endian_();
+  bool swap_endianness = read_is_big_endian != is_big_endian();
 
   if (col_contiguous) {
     std::reverse(shape.begin(), shape.end());
@@ -230,7 +214,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
   auto loaded_array = array(
       shape,
       dtype,
-      std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
+      std::make_shared<Load>(to_stream(s), in_stream, offset, swap_endianness),
       std::vector<array>{});
   if (col_contiguous) {
     loaded_array = transpose(loaded_array, s);
@@ -240,8 +224,8 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
 }
 
 /** Load array from file in .npy format */
-array load(const std::string& file, StreamOrDevice s) {
-  return load(std::make_shared<io::FileReader>(file), s);
+array load(std::string file, StreamOrDevice s) {
+  return load(std::make_shared<io::FileReader>(std::move(file)), s);
 }
 
 } // namespace mlx::core
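The removed is_big_endian_() helper shows the exact check the npy writer still depends on, now evidently provided by a shared is_big_endian() elsewhere. The classic union probe, reproduced as a standalone program:

    #include <cstdint>
    #include <cstdio>

    // Store a known 32-bit pattern and inspect the first byte: on a
    // little-endian machine the low-order byte 0x67 comes first, so the
    // first byte equals 0x01 only on big-endian hardware.
    inline bool is_big_endian() {
      union {
        int32_t i;
        uint8_t c[4];
      } b = {0x01234567};
      return b.c[0] == 0x01;
    }

    int main() {
      std::printf("big endian: %d\n", is_big_endian());
      return 0;
    }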
@@ -14,7 +14,7 @@ class Reader {
  public:
   virtual bool is_open() const = 0;
   virtual bool good() const = 0;
-  virtual size_t tell() const = 0;
+  virtual size_t tell() = 0; // tellp is non-const in iostream
   virtual void seek(
       int64_t off,
       std::ios_base::seekdir way = std::ios_base::beg) = 0;
@@ -26,7 +26,7 @@ class Writer {
  public:
   virtual bool is_open() const = 0;
   virtual bool good() const = 0;
-  virtual size_t tell() const = 0;
+  virtual size_t tell() = 0;
   virtual void seek(
       int64_t off,
       std::ios_base::seekdir way = std::ios_base::beg) = 0;
@@ -36,31 +36,31 @@ class Writer {
 
 class FileReader : public Reader {
  public:
-  explicit FileReader(const std::shared_ptr<std::ifstream>& is)
-      : is_(is), label_("stream") {}
-  explicit FileReader(const std::string& file_path)
-      : is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
-        label_(file_path) {}
+  explicit FileReader(std::ifstream is)
+      : is_(std::move(is)), label_("stream") {}
+  explicit FileReader(std::string file_path)
+      : is_(std::ifstream(file_path, std::ios::binary)),
+        label_(std::move(file_path)) {}
 
   bool is_open() const override {
-    return is_->is_open();
+    return is_.is_open();
   }
 
   bool good() const override {
-    return is_->good();
+    return is_.good();
   }
 
-  size_t tell() const override {
-    return is_->tellg();
+  size_t tell() override {
+    return is_.tellg();
   }
 
   void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
       override {
-    is_->seekg(off, way);
+    is_.seekg(off, way);
   }
 
   void read(char* data, size_t n) override {
-    is_->read(data, n);
+    is_.read(data, n);
   }
 
   std::string label() const override {
@@ -68,37 +68,37 @@ class FileReader : public Reader {
   }
 
  private:
-  std::shared_ptr<std::ifstream> is_;
+  std::ifstream is_;
   std::string label_;
 };
 
 class FileWriter : public Writer {
  public:
-  explicit FileWriter(const std::shared_ptr<std::ofstream>& is)
-      : os_(is), label_("stream") {}
-  explicit FileWriter(const std::string& file_path)
-      : os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
-        label_(file_path) {}
+  explicit FileWriter(std::ofstream os)
+      : os_(std::move(os)), label_("stream") {}
+  explicit FileWriter(std::string file_path)
+      : os_(std::ofstream(file_path, std::ios::binary)),
+        label_(std::move(file_path)) {}
 
   bool is_open() const override {
-    return os_->is_open();
+    return os_.is_open();
   }
 
   bool good() const override {
-    return os_->good();
+    return os_.good();
   }
 
-  size_t tell() const override {
-    return os_->tellp();
+  size_t tell() override {
+    return os_.tellp();
  }
 
   void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
       override {
-    os_->seekp(off, way);
+    os_.seekp(off, way);
   }
 
   void write(const char* data, size_t n) override {
-    os_->write(data, n);
+    os_.write(data, n);
   }
 
   std::string label() const override {
@@ -106,7 +106,7 @@ class FileWriter : public Writer {
   }
 
  private:
-  std::shared_ptr<std::ofstream> os_;
+  std::ofstream os_;
   std::string label_;
 };
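Worth noting about the new constructors: std::ifstream and std::ofstream are movable but not copyable, so holding the stream by value means taking constructor arguments by value and moving them into the members, which is exactly the shape of the new FileReader above. A minimal sketch of the same idea (a hypothetical class, not the MLX one):

    #include <fstream>
    #include <string>
    #include <utility>

    // A move-only stream held by value: the constructor takes the path by
    // value, uses it to open the stream, then moves it into the label.
    class OwningReader {
     public:
      explicit OwningReader(std::string path)
          : is_(path, std::ios::binary), label_(std::move(path)) {}

      bool good() const { return is_.good(); }

     private:
      std::ifstream is_;   // declared first, so initialized before label_
      std::string label_;
    };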
@@ -60,7 +60,7 @@ std::string dtype_to_safetensor_str(Dtype t) {
   }
 }
 
-Dtype dtype_from_safetensor_str(std::string str) {
+Dtype dtype_from_safetensor_str(std::string_view str) {
   if (str == ST_F32) {
     return float32;
   } else if (str == ST_F16) {
@@ -88,7 +88,8 @@ Dtype dtype_from_safetensor_str(std::string str) {
   } else if (str == ST_C64) {
     return complex64;
   } else {
-    throw std::runtime_error("[safetensor] unsupported dtype " + str);
+    throw std::runtime_error(
+        "[safetensor] unsupported dtype " + std::string(str));
   }
 }
@@ -129,14 +130,14 @@ SafetensorsLoad load_safetensors(
       }
       continue;
     }
-    std::string dtype = item.value().at("dtype");
-    std::vector<int> shape = item.value().at("shape");
-    std::vector<size_t> data_offsets = item.value().at("data_offsets");
+    const std::string& dtype = item.value().at("dtype");
+    const std::vector<int>& shape = item.value().at("shape");
+    const std::vector<size_t>& data_offsets = item.value().at("data_offsets");
     Dtype type = dtype_from_safetensor_str(dtype);
     auto loaded_array = array(
         shape,
         type,
-        std::make_unique<Load>(
+        std::make_shared<Load>(
             to_stream(s), in_stream, offset + data_offsets.at(0), false),
         std::vector<array>{});
     res.insert({item.key(), loaded_array});
@@ -207,19 +208,17 @@ void save_safetensors(
 }
 
 void save_safetensors(
-    const std::string& file_,
+    std::string file,
     std::unordered_map<std::string, array> a,
     std::unordered_map<std::string, std::string> metadata /* = {} */) {
-  // Open and check file
-  std::string file = file_;
-
   // Add .safetensors to file name if it is not there
   if (file.length() < 12 ||
       file.substr(file.length() - 12, 12) != ".safetensors")
     file += ".safetensors";
 
   // Serialize array
-  save_safetensors(std::make_shared<io::FileWriter>(file), a, metadata);
+  save_safetensors(
+      std::make_shared<io::FileWriter>(std::move(file)), a, metadata);
 }
 
 } // namespace mlx::core
@@ -11,7 +11,7 @@
 namespace mlx::core::linalg {
 
 Dtype at_least_float(const Dtype& d) {
-  return is_floating_point(d) ? d : promote_types(d, float32);
+  return issubdtype(d, inexact) ? d : promote_types(d, float32);
 }
 
 inline array l2_norm(
@@ -19,7 +19,7 @@ inline array l2_norm(
     const std::vector<int>& axis,
     bool keepdims,
     StreamOrDevice s) {
-  if (is_complex(a.dtype())) {
+  if (issubdtype(a.dtype(), complexfloating)) {
     return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
   } else {
     return sqrt(sum(square(a, s), axis, keepdims, s), s);
@@ -96,7 +96,7 @@ inline array matrix_norm(
 
 inline array matrix_norm(
     const array& a,
-    const std::string& ord,
+    std::string_view ord,
     const std::vector<int>& axis,
     bool keepdims,
     StreamOrDevice s) {
@@ -153,7 +153,7 @@ array norm(
 
 array norm(
     const array& a,
-    const std::string& ord,
+    std::string_view ord,
     const std::optional<std::vector<int>>& axis /* = std::nullopt */,
     bool keepdims /* = false */,
     StreamOrDevice s /* = {} */) {
@@ -195,7 +195,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   auto out = array::make_arrays(
       {a.shape(), a.shape()},
       {a.dtype(), a.dtype()},
-      std::make_unique<QRF>(to_stream(s)),
+      std::make_shared<QRF>(to_stream(s)),
       {astype(a, a.dtype(), s)});
   return std::make_pair(out[0], out[1]);
 }
@@ -234,8 +234,31 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   return array::make_arrays(
       {u_shape, s_shape, vt_shape},
       {a.dtype(), a.dtype(), a.dtype()},
-      std::make_unique<SVD>(to_stream(s)),
+      std::make_shared<SVD>(to_stream(s)),
       {a});
 }
 
+array inv(const array& a, StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::inv] Arrays must have type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
+           "with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(
+        "[linalg::inv] Inverses are only defined for square matrices.");
+  }
+
+  return array(
+      a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
+}
+
 } // namespace mlx::core::linalg
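For reference, both branches of l2_norm compute the same quantity; the complex branch must go through the modulus, since squaring a complex entry is not the same as squaring its magnitude:

    % Euclidean norm computed by l2_norm over the reduced axes:
    \|a\|_2 = \sqrt{\textstyle\sum_i |a_i|^2}

For real entries |a_i|^2 = a_i^2, which is why square() suffices on the real path.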
@@ -38,13 +38,13 @@ inline array norm(
 }
 array norm(
     const array& a,
-    const std::string& ord,
+    std::string_view ord,
     const std::optional<std::vector<int>>& axis = std::nullopt,
     bool keepdims = false,
     StreamOrDevice s = {});
 inline array norm(
     const array& a,
-    const std::string& ord,
+    std::string_view ord,
     int axis,
     bool keepdims = false,
     StreamOrDevice s = {}) {
@@ -64,4 +64,6 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
 
 std::vector<array> svd(const array& a, StreamOrDevice s = {});
 
+array inv(const array& a, StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
mlx/ops.cpp (715 changes): file diff suppressed because it is too large

mlx/ops.h (56 changes)
@@ -41,40 +41,33 @@ array linspace(
     StreamOrDevice s = {});
 
 /** Convert an array to the given data type. */
-array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
+array astype(array a, Dtype dtype, StreamOrDevice s = {});
 
 /** Create a view of an array with the given shape and strides. */
 array as_strided(
-    const array& a,
+    array a,
     std::vector<int> shape,
     std::vector<size_t> strides,
     size_t offset,
     StreamOrDevice s = {});
 
 /** Copy another array. */
-array copy(const array& a, StreamOrDevice s = {});
+array copy(array a, StreamOrDevice s = {});
 
 /** Fill an array of the given shape with the given value(s). */
 array full(
-    const std::vector<int>& shape,
-    const array& vals,
+    std::vector<int> shape,
+    array vals,
     Dtype dtype,
     StreamOrDevice s = {});
-array full(
-    const std::vector<int>& shape,
-    const array& vals,
-    StreamOrDevice s = {});
+array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
 template <typename T>
-array full(
-    const std::vector<int>& shape,
-    T val,
-    Dtype dtype,
-    StreamOrDevice s = {}) {
-  return full(shape, array(val, dtype), to_stream(s));
+array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
+  return full(std::move(shape), array(val, dtype), to_stream(s));
 }
 template <typename T>
-array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
-  return full(shape, array(val), to_stream(s));
+array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
+  return full(std::move(shape), array(val), to_stream(s));
 }
 
 /** Fill an array of the given shape with zeros. */
@@ -158,9 +151,7 @@ array expand_dims(
     StreamOrDevice s = {});
 
 /** Add a singleton dimension at the given axis. */
-inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
-  return expand_dims(a, std::vector<int>{axis}, s);
-}
+array expand_dims(const array& a, int axis, StreamOrDevice s = {});
 
 /** Slice an array. */
 array slice(
@@ -177,6 +168,23 @@ array slice(
     const std::vector<int>& stop,
     StreamOrDevice s = {});
 
+/** Update a slice from the source array */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    std::vector<int> strides,
+    StreamOrDevice s = {});
+
+/** Update a slice from the source array with stride 1 in each dimension */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    StreamOrDevice s = {});
+
 /** Split an array into sub-arrays along a given axis. */
 std::vector<array>
 split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
@@ -968,14 +976,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise = false,
     StreamOrDevice s = {});
 
 /** Softmax of an array. */
-array softmax(const array& a, StreamOrDevice s = {});
+array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
 
 /** Softmax of an array. */
-inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
-  return softmax(a, std::vector<int>{axis}, s);
+inline array
+softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
+  return softmax(a, std::vector<int>{axis}, precise, s);
 }
 
 /** Raise elements of a to the power of b element-wise */
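Context for the new precise flag: the function computed is the ordinary numerically stabilized softmax,

    % Softmax along the chosen axes, with the usual max-subtraction
    % stabilization; m is the running maximum over the reduced axis.
    \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_k e^{x_k - m}},
    \qquad m = \max_j x_j

and precise plausibly selects only a higher accumulation precision for the exponentials and their sum (for example, float32 accumulation for float16 inputs). That reading is an inference from the flag name, not something this header states.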
@@ -110,7 +110,11 @@ std::vector<array> Primitive::jvp(
     const std::vector<array>&,
     const std::vector<array>&,
     const std::vector<int>&) {
-  throw std::invalid_argument("Primitive's jvp not implemented.");
+  std::ostringstream msg;
+  msg << "[Primitive::jvp] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
 };
 
 std::vector<array> Primitive::vjp(
@@ -118,13 +122,21 @@ std::vector<array> Primitive::vjp(
     const std::vector<array>&,
     const std::vector<int>&,
     const std::vector<array>&) {
-  throw std::invalid_argument("Primitive's vjp not implemented.");
+  std::ostringstream msg;
+  msg << "[Primitive::vjp] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
 };
 
 std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
     const std::vector<array>&,
     const std::vector<int>&) {
-  throw std::invalid_argument("Primitive's vmap not implemented.");
+  std::ostringstream msg;
+  msg << "[Primitive::vmap] Not implemented for ";
+  print(msg);
+  msg << ".";
+  throw std::invalid_argument(msg.str());
 };
 
 std::vector<std::vector<int>> Primitive::output_shapes(
@@ -235,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
   return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
 }
 
+std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto maybe_move_ax = [this](auto& arr, auto ax) {
+    return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
+  };
+  auto a = maybe_move_ax(inputs[0], axes[0]);
+  auto b = maybe_move_ax(inputs[1], axes[1]);
+  auto c = maybe_move_ax(inputs[2], axes[2]);
+  return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
+}
+
 bool Arange::is_equivalent(const Primitive& other) const {
   const Arange& a_other = static_cast<const Arange&>(other);
   return (
@@ -1103,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{equal(a, b, stream())}, axes};
+  return {{equal(a, b, stream())}, {to_ax}};
 }
 
 std::vector<array> Equal::vjp(
@@ -1243,7 +1267,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
     {array(
         out_shape,
         real_ && inverse_ ? float32 : complex64,
-        std::make_unique<FFT>(stream(), fft_axes, inverse_, real_),
+        std::make_shared<FFT>(stream(), fft_axes, inverse_, real_),
         {in})},
     {ax}};
 }
@@ -1353,7 +1377,7 @@ std::pair<std::vector<array>, std::vector<int>> Full::vmap(
   assert(axes.size() == 1);
   auto& in = inputs[0];
   auto out =
-      array(in.shape(), in.dtype(), std::make_unique<Full>(stream()), {in});
+      array(in.shape(), in.dtype(), std::make_shared<Full>(stream()), {in});
   return {{out}, axes};
 }
 
@@ -1444,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{greater(a, b, stream())}, axes};
+  return {{greater(a, b, stream())}, {to_ax}};
 }
 
 std::vector<array> Greater::vjp(
@@ -1471,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{greater_equal(a, b, stream())}, axes};
+  return {{greater_equal(a, b, stream())}, {to_ax}};
 }
 
 std::vector<array> GreaterEqual::vjp(
@@ -1498,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{less(a, b, stream())}, axes};
+  return {{less(a, b, stream())}, {to_ax}};
 }
 
 std::vector<array> Less::vjp(
@@ -1525,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
-  return {{less_equal(a, b, stream())}, axes};
+  return {{less_equal(a, b, stream())}, {to_ax}};
 }
 
 std::vector<array> LessEqual::vjp(
@@ -1580,7 +1604,7 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
     {array(
         in.shape(),
         in.dtype(),
-        std::make_unique<Log>(stream(), base_),
+        std::make_shared<Log>(stream(), base_),
         {in})},
     axes};
 }
@@ -1772,6 +1796,17 @@ std::vector<array> Matmul::vjp(
   return vjps;
 }
 
+std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto maybe_move_ax = [this](auto& arr, auto ax) {
+    return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
+  };
+  auto a = maybe_move_ax(inputs[0], axes[0]);
+  auto b = maybe_move_ax(inputs[1], axes[1]);
+  return {{matmul(a, b, stream())}, {0}};
+}
+
 std::vector<array> Maximum::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
@@ -2224,7 +2259,7 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
   auto out = array(
       shape,
       get_dtype(),
-      std::make_unique<RandomBits>(stream(), shape, width_),
+      std::make_shared<RandomBits>(stream(), shape, width_),
       {key});
   return {{out}, {kax}};
 }
@@ -2458,7 +2493,7 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
     {array(
         in.shape(),
         out_dtype,
-        std::make_unique<Scan>(
+        std::make_shared<Scan>(
            stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
         {in})},
     axes};
@@ -2555,8 +2590,11 @@ std::vector<array> Scatter::vjp(
       break;
     case Scatter::Max:
     case Scatter::Min: {
-      auto mask = where(result == values, array({1}), array({0}));
-      vjps.push_back(multiply(cotangents[0], mask));
+      vjps.push_back(where(
+          equal(result, values, stream()),
+          cotangents[0],
+          array(0, cotangents[0].dtype()),
+          stream()));
       break;
     }
     default:
@@ -2814,6 +2852,114 @@ bool Slice::is_equivalent(const Primitive& other) const {
       end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
 }
 
+std::pair<std::vector<array>, std::vector<int>> SliceUpdate::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 2);
+  assert(axes.size() == 2);
+
+  auto start = start_indices_;
+  auto stop = end_indices_;
+  auto strides = strides_;
+
+  auto src = inputs[0];
+  auto upd = inputs[1];
+
+  auto src_ax = axes[0];
+  auto upd_ax = axes[1];
+
+  // No vmapping needed
+  if (src_ax == -1 && upd_ax == -1) {
+    return {{slice_update(src, upd, start, stop, strides, stream())}, {-1}};
+  }
+
+  // Broadcast src
+  if (src_ax == -1) {
+    src = expand_dims(src, upd_ax, stream());
+    auto shape = src.shape();
+    shape[upd_ax] = upd.shape(upd_ax);
+    src = broadcast_to(src, shape, stream());
+    src_ax = upd_ax;
+  }
+
+  // Broadcast upd
+  if (upd_ax == -1) {
+    upd = expand_dims(upd, src_ax, stream());
+    upd_ax = src_ax;
+  }
+
+  if (src_ax != upd_ax) {
+    upd = moveaxis(upd, upd_ax, src_ax, stream());
+  }
+
+  start.insert(start.begin() + src_ax, 0);
+  stop.insert(stop.begin() + src_ax, src.shape(src_ax));
+  strides.insert(strides.begin() + src_ax, 1);
+
+  return {{slice_update(src, upd, start, stop, strides, stream())}, {src_ax}};
+}
+
+std::vector<array> SliceUpdate::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  // Check inputs
+  assert(primals.size() == 2);
+
+  auto& cotan = cotangents[0];
+  auto& src = primals[0];
+  auto& upd = primals[1];
+
+  std::vector<array> vjps;
+
+  for (int num : argnums) {
+    // Vjp for source
+    if (num == 0) {
+      auto grad = slice_update(
+          cotan,
+          zeros_like(upd, stream()),
+          start_indices_,
+          end_indices_,
+          strides_,
+          stream());
+
+      vjps.push_back(grad);
+    }
+    // Vjp for updates
+    else {
+      auto grad =
+          slice(cotan, start_indices_, end_indices_, strides_, stream());
+
+      vjps.push_back(grad);
+    }
+  }
+
+  return vjps;
+}
+
+std::vector<array> SliceUpdate::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  // Check inputs
+  assert(primals.size() == 2);
+  return {slice_update(
+      tangents[0],
+      tangents[1],
+      start_indices_,
+      end_indices_,
+      strides_,
+      stream())};
+}
+
+bool SliceUpdate::is_equivalent(const Primitive& other) const {
+  const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other);
+  return (
+      start_indices_ == s_other.start_indices_ &&
+      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
+}
+
 std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -2829,7 +2975,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
   } else {
     softmax_axes.push_back(-2);
   }
-  return {{softmax(inputs[0], softmax_axes, stream())}, axes};
+  return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};
 }
 
 std::vector<array> Softmax::vjp(
@@ -2852,13 +2998,18 @@ std::vector<array> Softmax::jvp(
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
   assert(tangents.size() == 1);
-  auto s = softmax(primals[0], std::vector<int>{-1}, stream());
+  auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());
   auto sv = multiply(s, tangents[0], stream());
   return {subtract(
       sv,
       multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
 }
 
+bool Softmax::is_equivalent(const Primitive& other) const {
+  const Softmax& s_other = static_cast<const Softmax&>(other);
+  return precise_ == s_other.precise_;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -3157,7 +3308,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
   array out = array(
       std::vector<int>{},
       dtype_,
-      std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
+      std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
       inputs);
 
   return {{out}, {-1}};
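The SliceUpdate gradients added above have a compact summary. If out = slice_update(src, upd, w) writes the window w, then

    % Cotangent rules implemented by SliceUpdate::vjp:
    \bar{\mathrm{src}} = \mathrm{slice\_update}(\bar{\mathrm{out}},\,0,\,w),
    \qquad
    \bar{\mathrm{upd}} = \mathrm{slice}(\bar{\mathrm{out}},\,w)

that is, the source cotangent is the output cotangent with the written window zeroed (those source elements never reach the output), and the update cotangent is the output cotangent restricted to the window. This is exactly the zeros_like/slice pair in the code.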
@@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
     const std::vector<int>& argnums,
     const std::vector<array>& outputs) override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(AddMM)
 
   bool is_equivalent(const Primitive& other) const override;
@@ -419,12 +420,12 @@ class AsStrided : public UnaryPrimitive {
  public:
   explicit AsStrided(
       Stream stream,
-      const std::vector<int>& shape,
-      const std::vector<size_t>& strides,
+      std::vector<int> shape,
+      std::vector<size_t> strides,
       size_t offset)
       : UnaryPrimitive(stream),
-        shape_(shape),
-        strides_(strides),
+        shape_(std::move(shape)),
+        strides_(std::move(strides)),
         offset_(offset){};
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
@@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
     const std::vector<int>& argnums,
     const std::vector<array>& outputs) override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(Matmul)
   DEFINE_DEFAULT_IS_EQUIVALENT()
 };
@@ -1658,11 +1660,50 @@ class Slice : public UnaryPrimitive {
   std::vector<int> strides_;
 
   void eval(const std::vector<array>& inputs, array& out);
+
+  std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
+      const array& in);
+  void shared_buffer_slice(
+      const array& in,
+      const std::vector<size_t>& out_strides,
+      size_t data_offset,
+      array& out);
 };
 
+class SliceUpdate : public UnaryPrimitive {
+ public:
+  explicit SliceUpdate(
+      Stream stream,
+      const std::vector<int>& start_indices,
+      const std::vector<int>& end_indices,
+      const std::vector<int>& strides)
+      : UnaryPrimitive(stream),
+        start_indices_(start_indices),
+        end_indices_(end_indices),
+        strides_(strides){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(SliceUpdate)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  std::vector<int> start_indices_;
+  std::vector<int> end_indices_;
+  std::vector<int> strides_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+
+  std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
+};
+
 class Softmax : public UnaryPrimitive {
  public:
-  explicit Softmax(Stream stream) : UnaryPrimitive(stream){};
+  explicit Softmax(Stream stream, bool precise)
+      : UnaryPrimitive(stream), precise_(precise){};
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1670,11 +1711,13 @@ class Softmax : public UnaryPrimitive {
   DEFINE_VMAP()
   DEFINE_GRADS()
   DEFINE_PRINT(Softmax)
-  DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
 
+  bool is_equivalent(const Primitive& other) const override;
+
  private:
   void eval(const std::vector<array>& inputs, array& out);
+  bool precise_;
 };
 
 class Sort : public UnaryPrimitive {
@@ -1889,10 +1932,26 @@ class SVD : public Primitive {
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
+  DEFINE_VMAP()
   DEFINE_PRINT(SVD)
 
  private:
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
+/* Matrix inversion primitive. */
+class Inverse : public UnaryPrimitive {
+ public:
+  explicit Inverse(Stream stream) : UnaryPrimitive(stream){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& output) override;
+  void eval_gpu(const std::vector<array>& inputs, array& output) override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(Inverse)
+
+ private:
+  void eval(const std::vector<array>& inputs, array& output);
+};
+
 } // namespace mlx::core
@@ -66,7 +66,7 @@ array bits(
   return array(
       shape,
       get_dtype(),
-      std::make_unique<RandomBits>(to_stream(s), shape, width),
+      std::make_shared<RandomBits>(to_stream(s), shape, width),
       {key});
 }
 
@@ -90,6 +90,16 @@ T below_one() {
   return f;
 }
 
+// Get the next representable value above -1.0 for half precision
+// floating point types (fp16, bf16)
+template <typename T>
+T above_minus_one() {
+  T f = T(-1.0);
+  uint16_t* m = (uint16_t*)&f;
+  *m -= 1;
+  return f;
+}
+
 array uniform(
     const array& low,
     const array& high,
@@ -97,7 +107,7 @@ array uniform(
     Dtype dtype /* = float32 */,
     const std::optional<array>& key /*= nullopt */,
     StreamOrDevice s /* = {} */) {
-  if (!is_floating_point(dtype) && !is_complex(dtype)) {
+  if (!issubdtype(dtype, floating)) {
     throw std::invalid_argument(
         "Can only generate uniform numbers with real floating point type.");
   }
@@ -158,7 +168,17 @@ array normal(
     const std::optional<array>& key /*= nullopt */,
     StreamOrDevice s /* = {} */) {
   auto stream = to_stream(s);
-  auto low = array(std::nextafter(-1.0f, 0.0f), dtype);
+  auto get_low = [&dtype]() {
+    switch (dtype) {
+      case float16:
+        return array(above_minus_one<float16_t>(), dtype);
+      case bfloat16:
+        return array(above_minus_one<bfloat16_t>(), dtype);
+      default:
+        return array(std::nextafter(-1.0f, 0.0f), dtype);
+    }
+  };
+  auto low = get_low();
   auto high = array(1.0f, dtype);
   auto samples = uniform(low, high, shape, dtype, key, stream);
   samples =
@@ -179,7 +199,7 @@ array randint(
     Dtype dtype /* = int32 */,
     const std::optional<array>& key /*= nullopt */,
     StreamOrDevice s /* = {} */) {
-  if (!is_integral(dtype)) {
+  if (issubdtype(dtype, inexact)) {
     throw std::invalid_argument(
         "[randint] randint only accepts integer dtypes and bool.");
   }
@@ -192,7 +212,7 @@ array bernoulli(
     const std::vector<int>& shape,
     const std::optional<array>& key /*= nullopt */,
     StreamOrDevice s /* = {} */) {
-  if (!is_floating_point(p.dtype())) {
+  if (!issubdtype(p.dtype(), floating)) {
     throw std::invalid_argument(
         "[bernoulli] bernoulli probability `p` must be a float type.");
   }
@@ -228,7 +248,7 @@ array truncated_normal(
   // Same as
   // https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal
 
-  if (!is_floating_point(dtype)) {
+  if (!issubdtype(dtype, floating)) {
     throw std::invalid_argument(
         "[trunc_normal] trunc_normal only accepts floating point dtypes.");
   }
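The new above_minus_one helper nudges -1.0 one ULP toward zero by decrementing the raw 16-bit pattern, which works because decrementing the bits of a negative IEEE-754 value moves it toward zero. The same trick demonstrated on float (standard C++ has no fp16 type), where it agrees with std::nextafter:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Decrementing the raw bits of a negative IEEE-754 value moves it one
    // representable step toward zero, matching std::nextafter(-1.0f, 0.0f).
    int main() {
      float f = -1.0f;
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      bits -= 1;
      std::memcpy(&f, &bits, sizeof(f));
      std::printf("%d\n", f == std::nextafter(-1.0f, 0.0f));  // prints 1
      return 0;
    }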
@@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
 
   void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
   void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
-  void print(std::ostream&) override {}
+
+  DEFINE_PRINT(Synchronize);
 };
 
 // Initialize the static tracing counter from transforms_impl.h .
@@ -37,7 +38,7 @@ class Synchronizer : public Primitive {
 // are currently under a function transformation.
 int detail::InTracing::tracing_counter{0};
 
-void eval(const std::vector<array>& outputs) {
+void eval(std::vector<array> outputs) {
   std::function<void(const array&)> recurse;
   std::queue<array> tape;
   std::unordered_set<std::uintptr_t> cache;
@@ -52,8 +53,8 @@ void eval(std::vector<array> outputs) {
     }
   }
 
-  auto synchronizer =
-      array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
+  auto synchronizer = array(
+      {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
 
   size_t depth_counter = 0;
   recurse = [&](const array& a) {
@@ -76,7 +77,7 @@ void eval(std::vector<array> outputs) {
       // If the input is being computed on a different stream, we need to
       // manage the dependency.
       if (a.primitive().stream() != in.primitive().stream()) {
-        deps.insert({in.primitive_id(), std::shared_future<void>{}});
+        deps.insert({in.output(0).id(), std::shared_future<void>{}});
       }
     }
   }
@@ -96,8 +97,7 @@ void eval(std::vector<array> outputs) {
   };
 
   recurse(synchronizer);
-  uintptr_t synch_id = synchronizer.primitive_id();
-  deps.insert({synch_id, std::shared_future<void>{}});
+  deps.insert({synchronizer.id(), std::shared_future<void>{}});
 
   std::vector<std::shared_ptr<std::promise<void>>> ps;
   while (!tape.empty()) {
@@ -113,14 +113,13 @@ void eval(std::vector<array> outputs) {
     auto stream = arr.primitive().stream();
     std::vector<std::shared_future<void>> arr_deps;
     for (auto& in : arr.inputs()) {
-      // TODO that's a bug
-      if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
+      if (auto it = deps.find(in.output(0).id()); it != deps.end()) {
        arr_deps.push_back(it->second);
       }
     }
-    std::shared_ptr<std::promise<void>> p{nullptr};
-    if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
-      p = std::make_unique<std::promise<void>>();
+    std::shared_ptr<std::promise<void>> p;
+    if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
+      p = std::make_shared<std::promise<void>>();
       ps.push_back(p);
       it->second = p->get_future().share();
     }
@@ -154,7 +153,7 @@ void eval(std::vector<array> outputs) {
   }
 }
 
-  deps[synch_id].wait();
+  deps[synchronizer.id()].wait();
 }
 
 std::pair<std::vector<array>, std::vector<array>> vjp(
@@ -655,6 +654,7 @@ std::vector<array> vmap_replace(
   }
 
   auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
+
   // For each primitive's outputs add its id, the vout id and the vax
   auto outputs = a.outputs();
   for (int i = 0; i < v_outputs.size(); ++i) {
@@ -781,7 +781,7 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
   }
 
   return array::make_arrays(
-      shapes,
+      std::move(shapes),
       dtypes,
       std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
       inputs);
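The scheduler above keys cross-stream dependencies by the producing array's id (rather than the primitive's id, which the removed TODO flagged as a bug) and publishes completion through std::shared_future, so any number of downstream consumers can block on one producer. A minimal sketch of that mechanism:

    #include <cstdio>
    #include <future>
    #include <thread>

    // The producer fulfills a promise when its work is done; consumers on
    // other streams hold copies of the shared_future and wait on it.
    int main() {
      std::promise<void> done;
      std::shared_future<void> ready = done.get_future().share();

      std::thread producer([&] { done.set_value(); });
      std::thread consumer([ready] {
        ready.wait();
        std::printf("dependency satisfied\n");
      });

      producer.join();
      consumer.join();
      return 0;
    }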
@@ -6,10 +6,10 @@
 
 namespace mlx::core {
 
-void eval(const std::vector<array>& outputs);
+void eval(std::vector<array> outputs);
 
-template <typename... Arrays>
-void eval(Arrays... outputs) {
+template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+void eval(Arrays&&... outputs) {
   eval(std::vector<array>{std::forward<Arrays>(outputs)...});
 }
@@ -20,12 +20,12 @@ std::vector<array> vmap_replace(
 // idea.
 std::function<std::vector<array>(const std::vector<array>&)> compile(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    size_t fun_id,
+    std::uintptr_t fun_id,
     bool shapeless = false,
     std::vector<uint64_t> constants = {});
 
 // Erase cached compile functions
-void compile_erase(size_t fun_id);
+void compile_erase(std::uintptr_t fun_id);
 
 // Create an InTracing object during tracing operations to signify to the rest
 // of the codebase that we are during tracing so evals should not throw away
Some files were not shown because too many files have changed in this diff.