Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)

Commit: d1f86272a2 ("angelos's commit files")
Parent: 8ca7f9e8e9

.gitignore (vendored) · Normal file · 75 lines
@@ -0,0 +1,75 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Metal libraries
*.metallib

# Distribution / packaging
python/mlx/share
python/mlx/include
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# vim
*.swp

# Ignore build dir
build/

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# VSCode
.vscode/
.DS_Store

.pre-commit-config.yaml · Normal file · 9 lines
@@ -0,0 +1,9 @@
repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v14.0.6
    hooks:
      - id: clang-format
  - repo: https://github.com/psf/black
    rev: 22.10.0
    hooks:
      - id: black

CMakeLists.txt · Normal file · 197 lines
@@ -0,0 +1,197 @@
cmake_minimum_required(VERSION 3.24)

project(mlx LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)

# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.0.1)
endif()

# ----------------------------- Lib -----------------------------

include(FetchContent)
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
cmake_policy(SET CMP0135 NEW)

add_library(mlx)

if (MLX_BUILD_METAL)
  find_library(METAL_LIB Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB QuartzCore)
endif()

if (MLX_BUILD_METAL AND NOT METAL_LIB)
  message(STATUS "Metal not found. Unable to build GPU")
elseif (MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")
  add_compile_definitions(_METAL_)

  execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'"
                  OUTPUT_VARIABLE MACOS_VERSION)

  message(STATUS "Detected macOS version ${MACOS_VERSION}")
  if (${MACOS_VERSION} GREATER_EQUAL 14.0)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
  elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
  else()
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
  endif()

  FetchContent_Declare(
    metal_cpp
    URL ${METAL_CPP_URL}
  )

  FetchContent_MakeAvailable(metal_cpp)
  target_include_directories(
    mlx PUBLIC
    $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
    $<INSTALL_INTERFACE:include/metal_cpp>
  )
  target_link_libraries(
    mlx
    ${METAL_LIB}
    ${FOUNDATION_LIB}
    ${QUARTZ_LIB})
endif()

find_library(ACCELERATE_LIBRARY Accelerate)
if (ACCELERATE_LIBRARY)
  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
  set(MLX_BUILD_ACCELERATE ON)
  target_link_libraries(mlx ${ACCELERATE_LIBRARY})
  add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
  message(STATUS "Accelerate not found, using default backend.")
  set(MLX_BUILD_ACCELERATE OFF)
  #set(BLA_VENDOR Generic)
  find_package(BLAS REQUIRED)
  if (NOT BLAS_FOUND)
    message(FATAL_ERROR "Must have BLAS installed")
  endif()
  # TODO find a cleaner way to do this
  find_path(BLAS_INCLUDE_DIRS cblas.h
    /usr/include
    /usr/local/include
    $ENV{BLAS_HOME}/include)
  message(STATUS ${BLAS_LIBRARIES})
  message(STATUS ${BLAS_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
  target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

target_include_directories(
  mlx
  PUBLIC
  $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
  $<INSTALL_INTERFACE:include>
)

if (MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
  find_package(Python COMPONENTS Interpreter Development)
  find_package(pybind11 CONFIG REQUIRED)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

if (MLX_BUILD_TESTS)
  include(CTest)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()

if (MLX_BUILD_EXAMPLES)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()

if (MLX_BUILD_BENCHMARKS)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()

# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)

# Install library
install(
  TARGETS mlx
  EXPORT MLXTargets
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
  INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)

# Install headers
install(
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
  FILES_MATCHING PATTERN "*.h"
)

# Install metal dependencies
if (MLX_BUILD_METAL)

  # Install metal cpp
  install(
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source
  )

endif()

# Install cmake config
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)

install(
  EXPORT MLXTargets
  FILE MLXTargets.cmake
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

include(CMakePackageConfigHelpers)

write_basic_package_version_file(
  ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  COMPATIBILITY SameMajorVersion
  VERSION ${MLX_VERSION}
)

configure_package_config_file(
  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
  ${MLX_CMAKE_BUILD_CONFIG}
  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
  NO_CHECK_REQUIRED_COMPONENTS_MACRO
  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)

install(
  FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

install(
  DIRECTORY ${CMAKE_MODULE_PATH}/
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)

benchmarks/cpp/time_utils.h · Normal file · 38 lines
@@ -0,0 +1,38 @@
#pragma once

#include <chrono>
#include <iomanip>
#include <iostream>

#include "mlx/mlx.h"

#define milliseconds(x) \
  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
#define time_now() std::chrono::high_resolution_clock::now()

#define TIME(FUNC, ...)                                                        \
  std::cout << "Timing " << #FUNC << " ... " << std::flush                     \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
            << std::endl;

#define TIMEM(MSG, FUNC, ...)                                                  \
  std::cout << "Timing "                                                       \
            << "(" << MSG << ") " << #FUNC << " ... " << std::flush            \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
            << std::endl;

template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
  // warmup
  for (int i = 0; i < 5; ++i) {
    eval(fn(std::forward<Args>(args)...));
  }

  int num_iters = 100;
  auto start = time_now();
  for (int i = 0; i < num_iters; i++) {
    eval(fn(std::forward<Args>(args)...));
  }
  auto end = time_now();
  return milliseconds(end - start) / static_cast<double>(num_iters);
}

benchmarks/python/batch_matmul_bench.py · Normal file · 60 lines
@@ -0,0 +1,60 @@
import argparse
import mlx.core as mx

from time_utils import time_fn

B = 8
T = 1024
D = 512


def time_batch_matmul():
    mx.random.seed(3)
    a = mx.random.uniform(shape=(B, T, D))
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=(B, T, D))
    mx.eval(a, b, c)

    time_fn(mx.matmul, a, b)

    def batch_vjp_first():
        return mx.vjp(mx.matmul, [a, b], [c])[1][0]

    time_fn(batch_vjp_first)

    def batch_vjp_second():
        return mx.vjp(mx.matmul, [a, b], [c])[1][1]

    time_fn(batch_vjp_second)


def time_unbatch_matmul():
    mx.random.seed(3)
    a = mx.random.uniform(shape=(B * T, D))
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=(B * T, D))
    mx.eval(a, b, c)
    time_fn(mx.matmul, a, b)

    def unbatch_vjp_first():
        return mx.matmul(c, mx.transpose(b))

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        return mx.matmul(mx.transpose(a), c)

    time_fn(unbatch_vjp_second)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    if args.gpu:
        mx.set_default_device(mx.gpu)
    else:
        mx.set_default_device(mx.cpu)

    time_batch_matmul()
    time_unbatch_matmul()

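For context (not part of the commit): the unbatched helpers above hard-code the matmul pullbacks, i.e. for C = A @ B the cotangent gradients are dA = cotan @ B^T and dB = A^T @ cotan. A small sketch, assuming only the `mx.vjp` and `mx.allclose` behavior used elsewhere in this commit, that checks the hand-written formulas agree with the transform:

    import mlx.core as mx

    a = mx.random.uniform(shape=(4, 3))
    b = mx.random.uniform(shape=(3, 5))
    cotan = mx.random.uniform(shape=(4, 5))

    # mx.vjp returns (outputs, vjps); vjps holds the pullbacks w.r.t. a and b
    _, vjps = mx.vjp(mx.matmul, [a, b], [cotan])
    assert mx.allclose(vjps[0], mx.matmul(cotan, mx.transpose(b))).item()
    assert mx.allclose(vjps[1], mx.matmul(mx.transpose(a), cotan)).item()
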
benchmarks/python/time_utils.py · Normal file · 20 lines
@@ -0,0 +1,20 @@
import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")

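A hypothetical usage of `time_fn` (the function `add_and_sum` below is illustrative, not part of the commit):

    import mlx.core as mx
    from time_utils import time_fn

    def add_and_sum(x, y):
        return mx.sum(mx.add(x, y))

    x = mx.random.uniform(shape=(1024, 1024))
    y = mx.random.uniform(shape=(1024, 1024))
    mx.eval(x, y)  # materialize inputs before timing

    time_fn(add_and_sum, x, y)  # prints e.g. "Timing add_and_sum ... 0.41234 msec"
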
docs/.clang-format · Normal file · 2 lines
@@ -0,0 +1,2 @@
DisableFormat: true
SortIncludes: Never

docs/Makefile · Normal file · 18 lines
@@ -0,0 +1,18 @@
# Minimal makefile for Sphinx documentation

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = sphinx-build
SOURCEDIR   = src
BUILDDIR    = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docs/src/examples/mlp.rst · Normal file · 131 lines
@@ -0,0 +1,131 @@
.. _mlp:

Multi-Layer Perceptron
----------------------

In this example we'll learn to use ``mlx.nn`` by implementing a simple
multi-layer perceptron to classify MNIST.

As a first step import the MLX packages we need:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
   import mlx.optimizers as optim

   import numpy as np


The model is defined as the ``MLP`` class which inherits from
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:

1. Define an ``__init__`` where the parameters and/or submodules are set up. See
   the :ref:`Module class docs<module_class>` for more information on how
   :class:`mlx.nn.Module` registers parameters.
2. Define a ``__call__`` where the computation is implemented.

.. code-block:: python

   class MLP(nn.Module):
       def __init__(
           self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
       ):
           super().__init__()
           layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
           self.layers = [
               nn.Linear(idim, odim)
               for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
           ]

       def __call__(self, x):
           for l in self.layers[:-1]:
               x = mx.maximum(l(x), 0.0)
           return self.layers[-1](x)


We define the loss function which takes the mean of the per-example cross
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
commonly used loss functions.

.. code-block:: python

   def loss_fn(model, X, y):
       return mx.mean(nn.losses.cross_entropy(model(X), y))

We also need a function to compute the accuracy of the model on the validation
set:

.. code-block:: python

   def eval_fn(model, X, y):
       return mx.mean(mx.argmax(model(X), axis=1) == y)

Next, set up the problem parameters and load the data:

.. code-block:: python

   num_layers = 2
   hidden_dim = 32
   num_classes = 10
   batch_size = 256
   num_epochs = 10
   learning_rate = 1e-1

   # Load the data
   import mnist
   train_images, train_labels, test_images, test_labels = map(
       mx.array, mnist.mnist()
   )

Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:

.. code-block:: python

   def batch_iterate(batch_size, X, y):
       perm = mx.array(np.random.permutation(y.size))
       for s in range(0, y.size, batch_size):
           ids = perm[s : s + batch_size]
           yield X[ids], y[ids]


Finally, we put it all together by instantiating the model, the
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:

.. code-block:: python

   # Load the model
   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
   mx.eval(model.parameters())

   # Get a function which gives the loss and gradient of the
   # loss with respect to the model's trainable parameters
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

   # Instantiate the optimizer
   optimizer = optim.SGD(learning_rate=learning_rate)

   for e in range(num_epochs):
       for X, y in batch_iterate(batch_size, train_images, train_labels):
           loss, grads = loss_and_grad_fn(model, X, y)

           # Update the optimizer state and model parameters
           # in a single call
           optimizer.update(model, grads)

           # Force a graph evaluation
           mx.eval(model.parameters(), optimizer.state)

       accuracy = eval_fn(model, test_images, test_labels)
       print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")


.. note::
   The :func:`mlx.nn.value_and_grad` function is a convenience function to get
   the gradient of a loss with respect to the trainable parameters of a model.
   This should not be confused with :func:`mlx.core.value_and_grad`.

The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
is available in the MLX GitHub repo.

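The note at the end of mlp.rst distinguishes mlx.nn.value_and_grad (gradients with respect to a module's trainable parameters) from mlx.core.value_and_grad (gradients with respect to explicit array arguments). A minimal sketch of the core variant, for illustration only:

    import mlx.core as mx

    def sq_loss(w, x):
        return mx.mean(mx.square(w * x))

    # mx.value_and_grad differentiates with respect to the first
    # positional argument (w here), not a module's parameters.
    grad_fn = mx.value_and_grad(sq_loss)
    value, grad = grad_fn(mx.array(2.0), mx.array([1.0, 2.0, 3.0]))
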
docs/src/python/array.rst · Normal file · 45 lines
@@ -0,0 +1,45 @@
.. _array:

Array
=====

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   array
   array.astype
   array.item
   array.tolist
   array.dtype
   array.ndim
   array.shape
   array.size
   Dtype
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.cos
   array.dtype
   array.exp
   array.log
   array.log1p
   array.logsumexp
   array.max
   array.mean
   array.min
   array.prod
   array.reciprocal
   array.reshape
   array.rsqrt
   array.sin
   array.split
   array.sqrt
   array.square
   array.sum
   array.transpose
   array.T
   array.var

docs/src/python/ops.rst · Normal file · 94 lines
@@ -0,0 +1,94 @@
.. _ops:

Operations
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   all
   allclose
   any
   arange
   arccos
   arccosh
   arcsin
   arcsinh
   arctan
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   broadcast_to
   concatenate
   convolve
   conv1d
   conv2d
   cos
   cosh
   divide
   equal
   erf
   erfinv
   exp
   expand_dims
   full
   greater
   greater_equal
   less
   less_equal
   load
   log
   log2
   log10
   log1p
   logaddexp
   logical_not
   logsumexp
   matmul
   max
   maximum
   mean
   min
   minimum
   multiply
   negative
   ones
   ones_like
   partition
   pad
   prod
   reciprocal
   reshape
   rsqrt
   save
   savez
   savez_compressed
   sigmoid
   sign
   sin
   sinh
   softmax
   sort
   split
   sqrt
   square
   squeeze
   stop_gradient
   subtract
   sum
   take
   take_along_axis
   tan
   tanh
   transpose
   var
   where
   zeros
   zeros_like

examples/cpp/timer.h · Normal file · 18 lines
@@ -0,0 +1,18 @@
#pragma once

#include <chrono>

namespace timer {

using namespace std::chrono;

template <typename R, typename P>
inline double seconds(duration<R, P> x) {
  return duration_cast<nanoseconds>(x).count() / 1e9;
}

inline auto time() {
  return high_resolution_clock::now();
}

} // namespace timer

examples/extensions/axpby/axpby.cpp · Normal file · 359 lines
@@ -0,0 +1,359 @@
#include <cassert>
#include <iostream>
#include <sstream>

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"

#include "axpby/axpby.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#endif

#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#endif

namespace mlx::core {

///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////

/**
 * Scale and sum two vectors elementwise
 * z = alpha * x + beta * y
 *
 * Follows numpy style broadcasting between x and y
 * Inputs are upcasted to floats if needed
 **/
array axpby(
    const array& x, // Input array x
    const array& y, // Input array y
    const float alpha, // Scaling factor for x
    const float beta, // Scaling factor for y
    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
  // Promote dtypes between x and y as needed
  auto promoted_dtype = promote_types(x.dtype(), y.dtype());

  // Upcast to float32 for non-floating point inputs x and y
  auto out_dtype = is_floating_point(promoted_dtype)
      ? promoted_dtype
      : promote_types(promoted_dtype, float32);

  // Cast x and y up to the determined dtype (on the same stream s)
  auto x_casted = astype(x, out_dtype, s);
  auto y_casted = astype(y, out_dtype, s);

  // Broadcast the shapes of x and y (on the same stream s)
  auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
  auto out_shape = broadcasted_inputs[0].shape();

  // Construct the array as the output of the Axpby primitive
  // with the broadcasted and upcasted arrays as inputs
  return array(
      /* const std::vector<int>& shape = */ out_shape,
      /* Dtype dtype = */ out_dtype,
      /* std::unique_ptr<Primitive> primitive = */
      std::make_unique<Axpby>(to_stream(s), alpha, beta),
      /* const std::vector<array>& inputs = */ broadcasted_inputs);
}

///////////////////////////////////////////////////////////////////////////////
// Primitive Common Backend Implementation
///////////////////////////////////////////////////////////////////////////////

template <typename T>
void axpby_impl(
    const array& x,
    const array& y,
    array& out,
    float alpha_,
    float beta_) {
  // We only allocate memory when we are ready to fill the output
  // malloc_or_wait synchronously allocates available memory
  // There may be a wait executed here if the allocation is requested
  // under memory-pressured conditions
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Collect input and output data pointers
  const T* x_ptr = x.data<T>();
  const T* y_ptr = y.data<T>();
  T* out_ptr = out.data<T>();

  // Cast alpha and beta to the relevant types
  T alpha = static_cast<T>(alpha_);
  T beta = static_cast<T>(beta_);

  // Do the elementwise operation for each output
  for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
    // Map linear indices to offsets in x and y
    auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
    auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

    // We allocate the output to be contiguous and regularly strided
    // (defaults to row major) and hence it doesn't need additional mapping
    out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
  }
}

/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
  // Check the inputs (registered in the op while constructing the out array)
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];

  // Dispatch to the correct dtype
  if (out.dtype() == float32) {
    return axpby_impl<float>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == float16) {
    return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == bfloat16) {
    return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == complex64) {
    return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
  } else {
    throw std::runtime_error(
        "Axpby is only supported for floating point types.");
  }
}

///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////

#ifdef ACCELERATE_NEW_LAPACK

template <typename T>
void axpby_impl_accelerate(
    const array& x,
    const array& y,
    array& out,
    float alpha_,
    float beta_) {
  // Accelerate library provides catlas_saxpby which does
  // Y = (alpha * X) + (beta * Y) in place
  // To use it, we first copy the data in y over to the output array

  // This specialization requires both x and y be contiguous in the same mode
  // i.e: corresponding linear indices in both point to corresponding elements
  // The data in the output array is allocated to match the strides in y
  // such that x, y, and out are contiguous in the same mode and
  // no transposition is needed
  out.set_data(
      allocator::malloc_or_wait(y.data_size() * out.itemsize()),
      y.data_size(),
      y.strides(),
      y.flags());

  // We then copy over the elements using the contiguous vector specialization
  copy_inplace(y, out, CopyType::Vector);

  // Get x and y pointers for catlas_saxpby
  const T* x_ptr = x.data<T>();
  T* y_ptr = out.data<T>();

  T alpha = static_cast<T>(alpha_);
  T beta = static_cast<T>(beta_);

  // Call the inplace accelerate operator
  catlas_saxpby(
      /* N = */ out.size(),
      /* ALPHA = */ alpha,
      /* X = */ x_ptr,
      /* INCX = */ 1,
      /* BETA = */ beta,
      /* Y = */ y_ptr,
      /* INCY = */ 1);
}

/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];

  // Accelerate specialization for contiguous single precision float arrays
  if (out.dtype() == float32 &&
      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
       (x.flags().col_contiguous && y.flags().col_contiguous))) {
    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
    return;
  }

  // Fall back to common backend if specializations are not available
  eval(inputs, out);
}

#else // Accelerate not available

/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

#endif

///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation
///////////////////////////////////////////////////////////////////////////////

#ifdef _METAL_

/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Prepare inputs
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];

  // Each primitive carries the stream it should execute on
  // and each stream carries its device identifiers
  auto& s = stream();
  // We get the needed metal device using the stream
  auto& d = metal::device(s.device);

  // Prepare to specialize based on contiguity
  bool contiguous_kernel =
      (x.flags().row_contiguous && y.flags().row_contiguous) ||
      (x.flags().col_contiguous && y.flags().col_contiguous);

  // Allocate output memory with strides based on specialization
  if (contiguous_kernel) {
    out.set_data(
        allocator::malloc_or_wait(x.data_size() * out.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  // Resolve name of kernel (corresponds to axpby.metal)
  std::ostringstream kname;
  kname << "axpby_";
  kname << (contiguous_kernel ? "contiguous_" : "general_");
  kname << type_to_name(out);

  // Make sure the metal library is available and look for it
  // in the same folder as this executable if needed
  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname.str(), "mlx_ext");

  // Prepare to encode kernel
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  // Kernel parameters are registered with buffer indices corresponding to
  // those in the kernel declaration at axpby.metal
  int ndim = out.ndim();
  size_t nelem = out.size();

  // Encode input arrays to kernel
  set_array_buffer(compute_encoder, x, 0);
  set_array_buffer(compute_encoder, y, 1);

  // Encode output arrays to kernel
  set_array_buffer(compute_encoder, out, 2);

  // Encode alpha and beta
  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
  compute_encoder->setBytes(&beta_, sizeof(float), 4);

  // Encode shape, strides and ndim if needed
  if (!contiguous_kernel) {
    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
    compute_encoder->setBytes(&ndim, sizeof(int), 8);
  }

  // We launch 1 thread for each input and make sure that the number of
  // threads in any given threadgroup is not higher than the max allowed
  size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

  // Fix the 3D size of each threadgroup (in terms of threads)
  MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

  // Fix the 3D size of the launch grid (in terms of threads)
  MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

  // Launch the grid with the given number of threads divided among
  // the given threadgroups
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

#else // Metal is not available

/** Fail evaluation on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
  throw std::runtime_error("Axpby has no GPU implementation.");
}

#endif

///////////////////////////////////////////////////////////////////////////////
// Primitive Transforms
///////////////////////////////////////////////////////////////////////////////

/** The Jacobian-vector product. */
array Axpby::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  // Forward mode diff that pushes along the tangents
  // The jvp transform on the primitive can be built with ops
  // that are scheduled on the same stream as the primitive

  // If argnums = {0}, we only push along x in which case the
  // jvp is just the tangent scaled by alpha
  // Similarly, if argnums = {1}, the jvp is just the tangent
  // scaled by beta
  if (argnums.size() == 1) {
    auto scale = argnums[0] == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, tangents[0].dtype());
    return multiply(scale_arr, tangents[0], stream());
  }
  // If argnums = {0, 1}, we take contributions from both
  // which gives us jvp = tangent_x * alpha + tangent_y * beta
  else {
    return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
  }
}

/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
    const std::vector<array>& primals,
    const array& cotan,
    const std::vector<int>& argnums) {
  // Reverse mode diff
  std::vector<array> vjps;
  for (auto arg : argnums) {
    auto scale = arg == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, cotan.dtype());
    vjps.push_back(multiply(scale_arr, cotan, stream()));
  }
  return vjps;
}

/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  throw std::runtime_error("Axpby has no vmap implementation.");
}

/** Equivalence check **/
bool Axpby::is_equivalent(const Primitive& other) const {
  const Axpby& r_other = static_cast<const Axpby&>(other);
  return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}

} // namespace mlx::core

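The jvp comments above reduce to a simple linearity fact: for z = alpha*x + beta*y, pushing tangents (tx, ty) gives alpha*tx + beta*ty. A plain-Python sanity check of that rule against a finite difference, for illustration only:

    alpha, beta = 2.0, 3.0
    x, y = 1.5, -0.5
    tx, ty = 0.1, 0.2  # tangents along x and y

    jvp = alpha * tx + beta * ty  # directional derivative of z = alpha*x + beta*y

    eps = 1e-6
    z = alpha * x + beta * y
    z_eps = alpha * (x + eps * tx) + beta * (y + eps * ty)
    assert abs((z_eps - z) / eps - jvp) < 1e-6
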
examples/extensions/bindings.cpp · Normal file · 39 lines
@@ -0,0 +1,39 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "axpby/axpby.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

PYBIND11_MODULE(mlx_sample_extensions, m) {
  m.doc() = "Sample C++ and metal extensions for MLX";

  m.def(
      "axpby",
      &axpby,
      "x"_a,
      "y"_a,
      py::pos_only(),
      "alpha"_a,
      "beta"_a,
      py::kw_only(),
      "stream"_a = py::none(),
      R"pbdoc(
        Scale and sum two vectors elementwise
        ``z = alpha * x + beta * y``

        Follows numpy style broadcasting between ``x`` and ``y``
        Inputs are upcasted to floats if needed

        Args:
            x (array): Input array.
            y (array): Input array.
            alpha (float): Scaling factor for ``x``.
            beta (float): Scaling factor for ``y``.

        Returns:
            array: ``alpha * x + beta * y``
      )pbdoc");
}

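Once the extension is compiled, the module above would be used roughly as follows; the module name `mlx_sample_extensions` and the `axpby` signature come from the `PYBIND11_MODULE` block, everything else is illustrative:

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    x = mx.ones((3, 4))
    y = mx.ones((3, 4))

    # x and y are positional-only; alpha and beta follow; stream is keyword-only
    z = axpby(x, y, alpha=2.0, beta=0.5)
    print(z)  # every element is 2.0 * 1.0 + 0.5 * 1.0 = 2.5
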
mlx/CMakeLists.txt · Normal file · 36 lines
@@ -0,0 +1,36 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)

if (MLX_BUILD_ACCELERATE)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
  target_sources(
    mlx
    PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
  )
endif()

if (MLX_BUILD_METAL)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()

mlx/array.cpp · Normal file · 143 lines
@@ -0,0 +1,143 @@
#include <functional>

#include "mlx/array.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"

namespace mlx::core {

namespace {

std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
  std::vector<size_t> strides(shape.size());
  size_t cum_prod = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    strides[i] = cum_prod;
    cum_prod *= shape[i];
  }
  return {cum_prod, strides};
}

} // namespace

array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
  auto cval = static_cast<complex64_t>(val);
  init(&cval);
}

array::array(
    const std::vector<int>& shape,
    Dtype dtype,
    std::unique_ptr<Primitive> primitive,
    const std::vector<array>& inputs)
    : array_desc_(std::make_shared<ArrayDesc>(
          shape,
          dtype,
          std::move(primitive),
          inputs)) {}

array::array(std::initializer_list<float> data)
    : array_desc_(std::make_shared<ArrayDesc>(
          std::vector<int>{static_cast<int>(data.size())},
          float32)) {
  init(data.begin());
}

/* Build an array from a shared buffer */
array::array(
    allocator::Buffer data,
    const std::vector<int>& shape,
    Dtype dtype,
    deleter_t deleter)
    : array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
  set_data(data, deleter);
}

void array::detach() {
  array_desc_->inputs.clear();
  array_desc_->primitive = nullptr;
}

void array::eval(bool retain_graph /* = false */) {
  mlx::core::eval({*this}, retain_graph);
}

void array::set_data(allocator::Buffer buffer, deleter_t d) {
  array_desc_->data = std::make_shared<Data>(buffer, d);
  array_desc_->data_ptr = buffer.raw_ptr();
  array_desc_->data_size = size();
  array_desc_->flags.contiguous = true;
  array_desc_->flags.row_contiguous = true;
  auto max_dim = std::max_element(shape().begin(), shape().end());
  array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
}

void array::set_data(
    allocator::Buffer buffer,
    size_t data_size,
    std::vector<size_t> strides,
    Flags flags,
    deleter_t d) {
  array_desc_->data = std::make_shared<Data>(buffer, d);
  array_desc_->data_ptr = buffer.raw_ptr();
  array_desc_->data_size = data_size;
  array_desc_->strides = std::move(strides);
  array_desc_->flags = flags;
}

void array::copy_shared_buffer(
    const array& other,
    const std::vector<size_t>& strides,
    Flags flags,
    size_t data_size,
    size_t offset /* = 0 */) {
  array_desc_->data = other.array_desc_->data;
  array_desc_->strides = strides;
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  auto char_offset = sizeof(char) * itemsize() * offset;
  array_desc_->data_ptr = static_cast<void*>(
      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}

void array::copy_shared_buffer(const array& other) {
  copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}

array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
    : shape(shape), dtype(dtype) {
  std::tie(size, strides) = cum_prod(shape);
}

array::ArrayDesc::ArrayDesc(
    const std::vector<int>& shape,
    Dtype dtype,
    std::unique_ptr<Primitive> primitive,
    const std::vector<array>& inputs)
    : shape(shape),
      dtype(dtype),
      primitive(std::move(primitive)),
      inputs(inputs) {
  std::tie(size, strides) = cum_prod(shape);
  for (auto& in : inputs) {
    is_tracer |= in.is_tracer();
  }
}

// Needed because the Primitive type used in array.h is incomplete and the
// compiler needs to see the call to the destructor after the type is complete.
array::ArrayDesc::~ArrayDesc() = default;

array::ArrayIterator::reference array::ArrayIterator::operator*() const {
  auto start = std::vector<int>(arr.ndim(), 0);
  auto end = arr.shape();
  auto shape = arr.shape();
  shape.erase(shape.begin());
  start[0] = idx;
  end[0] = idx + 1;
  return reshape(slice(arr, start, end), shape);
};

} // namespace mlx::core

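`cum_prod` above computes the row-major strides (in elements) together with the total size: stride i is the product of all dimensions to the right of i. The same computation in a few lines of Python, for illustration:

    def row_major_strides(shape):
        strides = [1] * len(shape)
        size = 1
        for i in reversed(range(len(shape))):
            strides[i] = size  # product of trailing dimensions
            size *= shape[i]
        return size, strides

    assert row_major_strides([2, 3, 4]) == (24, [12, 4, 1])
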
mlx/backend/accelerate/softmax.cpp · Normal file · 323 lines
@@ -0,0 +1,323 @@
#include <cassert>
#include <limits>

#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

/**
 * Compute exp(x) in an optimizer friendly way as follows:
 *
 * First change the problem to computing 2**y where y = x / ln(2).
 *
 * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
 * `ipart` and y2 is fractional part. For the integer part we perform bit
 * shifting and for the fractional part we use a polynomial approximation.
 *
 * The algorithm and constants of the polynomial taken from
 * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
 * from Cephes math library.
 *
 * Note: The implementation below is a general fast exp. There could be faster
 * implementations for numbers strictly < 0.
 */
inline simd_float16 simd_fast_exp(simd_float16 x) {
  x *= 1.442695; // multiply with log_2(e)
  simd_float16 ipart, fpart;
  simd_int16 epart;
  x = simd_clamp(x, -80, 80);
  ipart = simd::floor(x + 0.5);
  fpart = x - ipart;

  x = 1.535336188319500e-4f;
  x = x * fpart + 1.339887440266574e-3f;
  x = x * fpart + 9.618437357674640e-3f;
  x = x * fpart + 5.550332471162809e-2f;
  x = x * fpart + 2.402264791363012e-1f;
  x = x * fpart + 6.931472028550421e-1f;
  x = x * fpart + 1.000000000000000f;

  // generate 2**ipart in the floating point representation using integer
  // bitshifting
  epart = (simd_int(ipart) + 127) << 23;

  return (*(simd_float16*)&epart) * x;
}

/**
 * The ARM neon equivalent of the fast exp above.
 */
inline float16x8_t neon_fast_exp(float16x8_t x) {
  x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
  x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
  x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14

  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
  float16x8_t fpart = vsubq_f16(x, ipart);

  x = vdupq_n_f16(1.535336188319500e-4f);
  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);

  // generate 2**ipart in the floating point representation using integer
  // bitshifting
  int16x8_t epart = vcvtq_s16_f16(ipart);
  epart = vaddq_s16(epart, vdupq_n_s16(15));
  epart = vshlq_n_s16(epart, 10);

  return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}

/**
 * Implementation of folding maximum for ARM neon. This should possibly be
 * refactored out of softmax.cpp at some point.
 */
inline float16_t neon_reduce_max(float16x8_t x) {
  float16x4_t y;
  y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
  y = vpmax_f16(y, y);
  y = vpmax_f16(y, y);
  return vget_lane_f16(y, 0);
}

/**
 * Implementation of folding sum for ARM neon. This should possibly be
 * refactored out of softmax.cpp at some point.
 */
inline float16_t neon_reduce_add(float16x8_t x) {
  float16x4_t y;
  float16x4_t zero = vdup_n_f16(0);
  y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
  y = vpadd_f16(y, zero);
  y = vpadd_f16(y, zero);
  return vget_lane_f16(y, 0);
}

template <typename T, typename VT>
struct AccelerateSimdOps {
  VT init(T a) {
    return a;
  }

  VT load(const T* a) {
    return *(VT*)a;
  }

  void store(T* dst, VT x) {
    *(VT*)dst = x;
  }

  VT max(VT a, VT b) {
    return simd_max(a, b);
  };

  VT exp(VT x) {
    return simd_fast_exp(x);
  }

  VT add(VT a, VT b) {
    return a + b;
  }

  VT sub(VT a, T b) {
    return a - b;
  }

  VT mul(VT a, VT b) {
    return a * b;
  }

  VT mul(VT a, T b) {
    return a * b;
  }

  T reduce_max(VT x) {
    return simd_reduce_max(x);
  }

  T reduce_add(VT x) {
    return simd_reduce_add(x);
  }
};

template <typename T, typename VT>
struct NeonFp16SimdOps {
  VT init(T a) {
    return vdupq_n_f16(a);
  }

  VT load(const T* a) {
    return vld1q_f16(a);
  }

  void store(T* dst, VT x) {
    vst1q_f16(dst, x);
  }

  VT max(VT a, VT b) {
    return vmaxq_f16(a, b);
  };

  VT exp(VT x) {
    return neon_fast_exp(x);
  }

  VT add(VT a, VT b) {
    return vaddq_f16(a, b);
  }

  VT sub(VT a, T b) {
    return vsubq_f16(a, vdupq_n_f16(b));
  }

  VT mul(VT a, VT b) {
    return vmulq_f16(a, b);
  }

  VT mul(VT a, T b) {
    return vmulq_f16(a, vdupq_n_f16(b));
  }

  T reduce_max(VT x) {
    return neon_reduce_max(x);
  }

  T reduce_add(VT x) {
    return neon_reduce_add(x);
  }
};

template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
  Ops ops;

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
  int M = in.shape().back();
  int L = in.data_size() / M;
  const T* current_in_ptr;
  T* current_out_ptr;

  for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
    // Find the maximum
    current_in_ptr = in_ptr;
    VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
    size_t s = M;
    while (s >= N) {
      vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
      current_in_ptr += N;
      s -= N;
    }
    T maximum = ops.reduce_max(vmaximum);
    while (s-- > 0) {
      maximum = std::max(maximum, *current_in_ptr);
      current_in_ptr++;
    }

    // Compute the normalizer and the exponentials
    VT vnormalizer = ops.init(0.0);
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
      ops.store(current_out_ptr, vexp);
      vnormalizer = ops.add(vnormalizer, vexp);
      current_in_ptr += N;
      current_out_ptr += N;
      s -= N;
    }
    T normalizer = ops.reduce_add(vnormalizer);
    while (s-- > 0) {
      T _exp = std::exp(*current_in_ptr - maximum);
      *current_out_ptr = _exp;
      normalizer += _exp;
      current_in_ptr++;
      current_out_ptr++;
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_out_ptr = out_ptr;
    s = M;
    while (s >= N) {
      ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
      current_out_ptr += N;
      s -= N;
    }
    while (s-- > 0) {
      *current_out_ptr *= normalizer;
      current_out_ptr++;
    }
  }
}

} // namespace

void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
    if (x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General);
      return x_copy;
    }
  };
  array in = check_input(std::move(inputs[0]));
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());

  switch (in.dtype()) {
    case bool_:
    case uint8:
    case uint16:
    case uint32:
    case uint64:
    case int8:
    case int16:
    case int32:
    case int64:
      throw std::invalid_argument(
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
          in, out);
      break;
    case float16:
      softmax<
          float16_t,
          float16x8_t,
          NeonFp16SimdOps<float16_t, float16x8_t>,
          8>(in, out);
      break;
    case bfloat16:
      eval(inputs, out);
      break;
    case complex64:
      eval(inputs, out);
      break;
  }
}

} // namespace mlx::core

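The fast-exp comment above describes the classic decomposition exp(x) = 2^(x/ln 2) = 2^i * 2^f, where 2^i comes from bit-shifting the exponent field and 2^f from a degree-6 polynomial. A Python sketch of the same decomposition, using math.ldexp in place of the bit shift and the polynomial coefficients from the code above (illustration only):

    import math

    def fast_exp(x):
        y = x * 1.4426950408889634  # x / ln(2), i.e. x * log2(e)
        i = math.floor(y + 0.5)     # integer part
        f = y - i                   # fractional part in [-0.5, 0.5]
        # Horner evaluation of the 2**f polynomial (Cephes coefficients)
        p = 1.535336188319500e-4
        for c in (1.339887440266574e-3, 9.618437357674640e-3,
                  5.550332471162809e-2, 2.402264791363012e-1,
                  6.931472028550421e-1, 1.0):
            p = p * f + c
        return math.ldexp(p, i)     # p * 2**i

    assert abs(fast_exp(1.0) - math.e) < 1e-5
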
216
mlx/backend/common/binary.cpp
Normal file
216
mlx/backend/common/binary.cpp
Normal file
@ -0,0 +1,216 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, U, Op> opsv(op);
|
||||
DefaultVectorScalar<T, U, Op> opvs(op);
|
||||
DefaultVectorVector<T, U, Op> opvv(op);
|
||||
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
comparison_op<bool, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
comparison_op<uint8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
comparison_op<uint16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
comparison_op<uint32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
comparison_op<uint64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
comparison_op<int8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
comparison_op<int16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
comparison_op<int32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
comparison_op<int64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
comparison_op<float16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
comparison_op<float, bool>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
comparison_op<bfloat16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
comparison_op<complex64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Add::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||
}
|
||||
|
||||
void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
}
|
||||
|
||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (equal_nan_) {
|
||||
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
|
||||
return x == y || (std::isnan(x) && std::isnan(y));
|
||||
});
|
||||
} else {
|
||||
comparison_op(
|
||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(
|
||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
|
||||
}
|
||||
|
||||
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(
|
||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
|
||||
}
|
||||
|
||||
void Less::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(
|
||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
|
||||
}
|
||||
|
||||
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(
|
||||
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
|
||||
}
|
||||
|
||||
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto op = [](auto x, auto y) {
|
||||
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||
auto maxval = (x > y) ? x : y;
|
||||
auto minval = (x > y) ? y : x;
|
||||
return (minval == -inf || maxval == inf)
|
||||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(std::exp(minval - maxval)));
|
||||
};
|
||||
if (is_floating_point(out.dtype())) {
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, op);
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, op);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, op);
|
||||
} else {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
void Maximum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}

void Minimum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}

void Multiply::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, [](auto x, auto y) { return x * y; });
}

void NotEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(
      inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
}

struct PowerFn {
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
    return std::pow(base, exp);
  }

  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
    if (exp < 0) {
      throw std::invalid_argument(
          "Integers cannot be raised to negative powers");
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }
};
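// Added note (not in the original source): the integral branch above is
// exponentiation by squaring, which needs O(log exp) multiplies. A worked
// trace for PowerFn{}(3, 5), with exp = 5 = 0b101:
//   bit 0 set:   res = 1 * 3 = 3;    base = 3 * 3 = 9
//   bit 1 clear: res unchanged;      base = 9 * 9 = 81
//   bit 2 set:   res = 3 * 81 = 243
// giving 3^5 == 243.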
void Power::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, PowerFn{});
}

void Subtract::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, [](auto x, auto y) { return x - y; });
}

} // namespace mlx::core
554
mlx/backend/common/binary.h
Normal file
@ -0,0 +1,554 @@
#pragma once

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

enum BinaryOpType {
  ScalarScalar,
  ScalarVector,
  VectorScalar,
  VectorVector,
  General,
};

BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = ScalarScalar;
  } else if (a.data_size() == 1 && b.flags().contiguous) {
    bopt = ScalarVector;
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    bopt = VectorScalar;
  } else if (
      a.flags().row_contiguous && b.flags().row_contiguous ||
      a.flags().col_contiguous && b.flags().col_contiguous) {
    bopt = VectorVector;
  } else {
    bopt = General;
  }
  return bopt;
}

void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
    BinaryOpType bopt) {
  switch (bopt) {
    case ScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case ScalarVector:
      out.set_data(
          allocator::malloc_or_wait(b.data_size() * out.itemsize()),
          b.data_size(),
          b.strides(),
          b.flags());
      break;
    case VectorScalar:
    case VectorVector:
      out.set_data(
          allocator::malloc_or_wait(a.data_size() * out.itemsize()),
          a.data_size(),
          a.strides(),
          a.flags());
      break;
    case General:
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
      break;
  }
}
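// Illustrative example (added, not part of the original header): for a
// scalar broadcast against a contiguous array, get_binary_op_type() picks
// ScalarVector and the output above simply adopts b's size, strides and
// flags:
//   array a(2.0f);          // data_size() == 1
//   array b = ones({4, 8}); // row contiguous (assuming the usual factory)
//   // get_binary_op_type(a, b) == ScalarVector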
struct UseDefaultBinaryOp {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
      a++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
      a++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < out.size(); ++i) {
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims1(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; i++) {
    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
    dst += stride;
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}
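// Added explanation of the index bookkeeping above: after the inner loop has
// advanced a_idx by a.strides()[1] * a.shape()[1], the correction
//   a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
// rewinds to the start of the row and then steps one row forward. Because the
// inputs may be broadcast, some strides can be 0, which this handles
// uniformly.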
template <typename T, typename U, typename Op>
void binary_op_dims2(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
      dst += stride;
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op) {
  switch (out.ndim()) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out, op);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out, op);
      return;
    case 3:
      binary_op_dims3<T, U, Op>(a, b, out, op);
      return;
    case 4:
      binary_op_dims4<T, U, Op>(a, b, out, op);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i++) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
  }
}
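// Added note: for ndim > 4 the loop above falls back to elem_to_loc(), which
// maps the flat output index i to each input's own offset through its shape
// and strides. That is what makes arbitrary broadcast/transpose layouts
// correct, at the cost of one index computation per element.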
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int dim,
    int stride) {
  // Number of dimensions to loop over for vectorized ops
  switch (dim) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out, op, stride);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out, op, stride);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i += stride) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
    dst += stride;
  }
}

template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize

  // Get the left-most dim such that the array is row contiguous after it
  auto& strides = out.strides();
  auto leftmost_rc_dim = [&strides](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a);
  auto b_rc_dim = leftmost_rc_dim(b);

  // Get the left-most dim such that the array is a broadcasted "scalar"
  // after it
  auto leftmost_s_dim = [](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a);
  auto b_s_dim = leftmost_s_dim(b);

  auto ndim = out.ndim();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = ScalarVector;
    dim = d;
  }

  // We can be sure dim > 0 since otherwise we would have used one of the
  // fully contiguous methods above, except when the flags do not correspond
  // to the underlying contiguity.
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
      break;
    case VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
      break;
    case ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
      break;
    default:
      binary_op_dispatch_dims<T, U>(a, b, out, op);
      break;
  }
}

template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  // with template specializations and overloading. Would it be simpler?

  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
    }
  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
    }
  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}

template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}
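// Typical use from a primitive (added illustration): Add::eval in binary.cpp
// boils down to
//   binary(a, b, out, [](auto x, auto y) { return x + y; });
// which selects the instantiation matching out.dtype() and then the fastest
// contiguity path via get_binary_op_type().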
} // namespace

} // namespace mlx::core
85
mlx/backend/common/fft.cpp
Normal file
@ -0,0 +1,85 @@
#include <numeric>

#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  std::vector<std::ptrdiff_t> strides_in(
      in.strides().begin(), in.strides().end());
  for (auto& s : strides_in) {
    s *= in.itemsize();
  }
  std::vector<std::ptrdiff_t> strides_out(
      out.strides().begin(), out.strides().end());
  for (auto& s : strides_out) {
    s *= out.itemsize();
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  std::vector<size_t> shape;
  if (out.dtype() == float32) {
    shape.insert(shape.end(), out.shape().begin(), out.shape().end());
  } else {
    shape.insert(shape.end(), in.shape().begin(), in.shape().end());
  }

  float scale = 1.0f;
  if (inverse_) {
    size_t nelem = std::accumulate(
        axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
          return x * shape[y];
        });
    scale /= nelem;
  }
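  // Added note on the convention: pocketfft multiplies every output element
  // by `scale`, so the forward transform is unnormalized (scale == 1) and
  // the inverse divides by the product of the transformed axis lengths,
  // e.g. an inverse over axes {0, 1} of a 4x8 input uses scale = 1/32.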
  if (in.dtype() == complex64 && out.dtype() == complex64) {
    auto in_ptr =
        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
    auto out_ptr =
        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
    pocketfft::c2c(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
  } else if (in.dtype() == float32 && out.dtype() == complex64) {
    auto in_ptr = in.data<float>();
    auto out_ptr =
        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
    pocketfft::r2c(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
  } else if (in.dtype() == complex64 && out.dtype() == float32) {
    auto in_ptr =
        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
    auto out_ptr = out.data<float>();
    pocketfft::c2r(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
  } else {
    throw std::runtime_error(
        "[FFT] Received unexpected input and output type combination.");
  }
}

} // namespace mlx::core
98
mlx/backend/common/softmax.cpp
Normal file
@ -0,0 +1,98 @@
#include <cassert>
#include <cmath>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T>
void softmax(const array& in, array& out) {
  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
  int N = in.shape().back();
  int M = in.data_size() / N;
  const T* current_in_ptr;
  T* current_out_ptr;

  for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
    // Find the maximum
    current_in_ptr = in_ptr;
    T maximum = *current_in_ptr;
    for (int j = 0; j < N; j++, current_in_ptr++) {
      maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
    }

    // Compute the normalizer and the exponentials
    T normalizer = 0;
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
      T expv = std::exp(*current_in_ptr - maximum);
      normalizer += expv;
      *current_out_ptr = expv;
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_out_ptr = out_ptr;
    for (int j = 0; j < N; j++, current_out_ptr++) {
      *current_out_ptr *= normalizer;
    }
  }
}
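// Stability note (added commentary): subtracting the row maximum keeps every
// exp() argument <= 0, so e.g. softmax over {1000, 1000} computes
// exp(0) / (exp(0) + exp(0)) == 0.5 instead of inf / inf.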
} // namespace

void Softmax::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
    if (x.strides().back() == 1) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General);
      return x_copy;
    }
  };
  array in = check_input(std::move(inputs[0]));
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());

  switch (in.dtype()) {
    case bool_:
    case uint8:
    case uint16:
    case uint32:
    case uint64:
    case int8:
    case int16:
    case int32:
    case int64:
      throw std::invalid_argument(
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<float>(in, out);
      break;
    case float16:
      softmax<float16_t>(in, out);
      break;
    case bfloat16:
      softmax<bfloat16_t>(in, out);
      break;
    case complex64:
      throw std::invalid_argument(
          "[Softmax] Not yet implemented for complex64");
      break;
  }
}

} // namespace mlx::core
200
mlx/backend/metal/allocator.cpp
Normal file
@ -0,0 +1,200 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"

#include <mach/vm_page_size.h>
#include <unistd.h>
#include <cstdlib>

namespace mlx::core {

namespace allocator {

Allocator& allocator() {
  return metal::allocator();
}

void* Buffer::raw_ptr() {
  return static_cast<MTL::Buffer*>(ptr_)->contents();
}

} // namespace allocator

namespace metal {

namespace {

BufferCache::BufferCache(MTL::Device* device)
    : device_(device),
      head_(nullptr),
      tail_(nullptr),
      pool_size_(0),
      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}

BufferCache::~BufferCache() {
  clear();
}

void BufferCache::clear() {
  std::lock_guard<std::mutex> lk(cache_mutex_);
  for (auto& [size, holder] : buffer_pool_) {
    if (holder->buf)
      holder->buf->release();
    delete holder;
  }
  buffer_pool_.clear();
  pool_size_ = 0;
  head_ = nullptr;
  tail_ = nullptr;
}

MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
  std::lock_guard<std::mutex> lk(cache_mutex_);

  // Find the closest buffer in the pool
  MTL::Buffer* pbuf = nullptr;
  auto it = buffer_pool_.lower_bound(size);

  // Make sure the request uses more than 50% of the reused buffer
  while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) {
    // Collect from the cache
    pbuf = it->second->buf;
    // Remove from cache
    remove_from_list(it->second);
    delete it->second;
    it = buffer_pool_.erase(it);
  }

  if (pbuf) {
    pool_size_ -= pbuf->length();
  }

  return pbuf;
}
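// Added sketch of the cache-hit policy above: lower_bound(size) finds the
// smallest cached buffer with length >= size, and the `it->first < 2 * size`
// guard rejects anything more than twice the request, so a reused buffer is
// always at least half utilized.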
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
  std::lock_guard<std::mutex> lk(cache_mutex_);

  // Add to cache
  if (buf) {
    BufferHolder* bh = new BufferHolder(buf);
    add_at_head(bh);
    pool_size_ += buf->length();
    buffer_pool_.insert({buf->length(), bh});
  }
}

size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
  min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;

  if (min_bytes_to_free >= 0.9 * pool_size_) {
    size_t old_pool_size = pool_size_;
    clear();
    return old_pool_size;
  } else {
    std::lock_guard<std::mutex> lk(cache_mutex_);
    size_t total_bytes_freed = 0;

    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
      if (tail_->buf) {
        total_bytes_freed += tail_->buf->length();
        tail_->buf->release();
        tail_->buf = nullptr;
      }
      remove_from_list(tail_);
    }

    pool_size_ -= total_bytes_freed;
    return total_bytes_freed;
  }
}

void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
  if (!to_add)
    return;

  if (!head_) {
    head_ = to_add;
    tail_ = to_add;
  } else {
    head_->prev = to_add;
    to_add->next = head_;
    head_ = to_add;
  }
}

void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
  if (!to_remove)
    return;

  // If in the middle
  if (to_remove->prev && to_remove->next) {
    to_remove->prev->next = to_remove->next;
    to_remove->next->prev = to_remove->prev;
  } else if (to_remove->prev && to_remove == tail_) { // If tail
    tail_ = to_remove->prev;
    tail_->next = nullptr;
  } else if (to_remove == head_ && to_remove->next) { // If head
    head_ = to_remove->next;
    head_->prev = nullptr;
  } else if (to_remove == head_ && to_remove == tail_) { // If only element
    head_ = nullptr;
    tail_ = nullptr;
  }

  to_remove->prev = nullptr;
  to_remove->next = nullptr;
}

} // namespace

MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      buffer_cache_(device_),
      peak_allocated_size_(0),
      block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}

Buffer MetalAllocator::malloc(size_t size) {
  // Align up memory
  if (size > vm_page_size) {
    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
  }

  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);

  // Prepare to allocate new memory as needed
  if (!buf) {
    // If we are under very high memory pressure, we don't allocate further
    if (device_->currentAllocatedSize() >= block_limit_) {
      return Buffer{nullptr};
    }

    // If we are still under memory pressure, try cleaning the cache
    if (buffer_cache_.can_garbage_collect()) {
      buffer_cache_.release_cached_buffers(size);
    }

    // Allocate a new buffer if needed
    size_t res_opt = MTL::ResourceStorageModeShared;
    res_opt |= MTL::ResourceHazardTrackingModeTracked;
    buf = device_->newBuffer(size, res_opt);
  }

  peak_allocated_size_ =
      std::max(peak_allocated_size_, device_->currentAllocatedSize());

  return Buffer{static_cast<void*>(buf)};
}
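// Added note on the round-up above: it is the usual
//   aligned = page * ((size + page - 1) / page)
// pattern. With a 16 KiB vm_page_size (an assumption; the value is queried
// at runtime), a 20 KiB request becomes 32 KiB, which makes freed buffers
// easier to reuse for similarly sized requests.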
void MetalAllocator::free(Buffer buffer) {
  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
  buffer_cache_.recycle_to_cache(buf);
}

MetalAllocator& allocator() {
  static MetalAllocator allocator_;
  return allocator_;
}

} // namespace metal

} // namespace mlx::core
76
mlx/backend/metal/allocator.h
Normal file
@ -0,0 +1,76 @@
#pragma once

#include <map>
#include <mutex>
#include <vector>

#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"

namespace mlx::core::metal {

using allocator::Buffer;

namespace {

class BufferCache {
 public:
  BufferCache(MTL::Device* device);
  ~BufferCache();
  void clear();

  MTL::Buffer* reuse_from_cache(size_t size);
  void recycle_to_cache(MTL::Buffer* buf);
  size_t release_cached_buffers(size_t min_bytes_to_free);

  bool can_garbage_collect() {
    return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
  }

 private:
  struct BufferHolder {
   public:
    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}

    BufferHolder* prev;
    BufferHolder* next;
    MTL::Buffer* buf;
  };

  void add_at_head(BufferHolder* to_add);
  void remove_from_list(BufferHolder* to_remove);

  MTL::Device* device_;
  std::mutex cache_mutex_;

  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_;
  BufferHolder* tail_;
  size_t pool_size_;
  size_t gc_limit_;
};

} // namespace

class MetalAllocator : public allocator::Allocator {
  /** Allocator for Metal GPUs. */
 public:
  virtual Buffer malloc(size_t size) override;
  virtual void free(Buffer buffer) override;

 private:
  MTL::Device* device_;
  MetalAllocator();
  friend MetalAllocator& allocator();

  // Caching allocator
  BufferCache buffer_cache_;

  // Allocation stats
  size_t peak_allocated_size_;
  size_t block_limit_;
};

MetalAllocator& allocator();

} // namespace mlx::core::metal
555
mlx/backend/metal/conv.cpp
Normal file
@ -0,0 +1,555 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

void explicit_gemm_conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<1>& conv_params) {
  // Pad input
  std::vector<int> padded_shape = {
      conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
  array in_padded(padded_shape, in.dtype(), nullptr, {});

  // Fill with zeros
  copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);

  // Pick input slice from padded
  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);

  // Make strided view
  std::vector<int> strided_shape = {
      conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};

  std::vector<size_t> strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * conv_params.str[0],
      in_padded.strides()[1],
      in_padded.strides()[2]};
  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  std::vector<int> strided_reshape = {
      conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_gpu(in_strided_view, in_strided, CopyType::General, s);

  // Perform gemm
  std::vector<array> copies = {in_padded, in_strided};
  mlx_matmul(
      s,
      d,
      /*a = */ in_strided,
      /*b = */ wt,
      /*c = */ out,
      /*M = */ strided_reshape[0],
      /*N = */ conv_params.O,
      /*K = */ strided_reshape[1],
      /*batch_size_out = */ 1,
      /*a_cols = */ strided_reshape[1],
      /*b_cols = */ strided_reshape[1],
      /*a_transposed = */ false,
      /*b_transposed = */ true,
      /*copies = */ copies);
}
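// Added explanation of the im2col trick above: reinterpreting the padded
// input with strides {s0, s1 * str, s1, s2} materializes the (N * oS,
// wS * C) im2col matrix with a single strided copy instead of a gather
// kernel, after which the convolution is one GEMM against the (O, wS * C)
// weight matrix (hence b_transposed = true).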
void conv_1D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
  // Make conv params
  MLXConvParams<1> conv_params{
      /* const int N = */ in.shape(0),
      /* const int C = */ in.shape(2),
      /* const int O = */ wt.shape(0),
      /* const int iS[NDIM] = */ {in.shape(1)},
      /* const int wS[NDIM] = */ {wt.shape(1)},
      /* const int oS[NDIM] = */ {out.shape(1)},
      /* const int str[NDIM] = */ {wt_strides[0]},
      /* const int pad[NDIM] = */ {padding[0]},
      /* const int dil[NDIM] = */ {wt_dilation[0]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2]},
  };

  // Direct to explicit gemm conv
  if (wt_dilation[0] == 1) {
    explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to fallback conv
  else {
    throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
  }
}

void slow_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  int bm = 16, bn = 8;
  int tm = 4, tn = 4;

  std::ostringstream kname;
  kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
        << "_tm" << tm << "_tn" << tn;

  // Encode and dispatch kernel
  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];

  size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
  size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
  size_t grid_dim_z = conv_params.N;

  MTL::Size group_dims = MTL::Size(bm, bn, 1);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);

  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, wt, 1);
  set_array_buffer(compute_encoder, out, 2);

  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

void implicit_gemm_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  int bm = 32, bn = 32, bk = 16;
  int wm = 2, wn = 2;

  std::ostringstream kname;
  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;

  // Encode and dispatch kernel
  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  int implicit_N = conv_params.O;
  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;

  size_t grid_dim_x = (implicit_N + bn - 1) / bn;
  size_t grid_dim_y = (implicit_M + bm - 1) / bm;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);

  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, wt, 1);
  set_array_buffer(compute_encoder, out, 2);

  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

void explicit_gemm_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  // Pad input
  std::vector<int> padded_shape = {
      conv_params.N,
      conv_params.iS[0] + 2 * conv_params.pad[0],
      conv_params.iS[1] + 2 * conv_params.pad[1],
      conv_params.C};
  array in_padded(padded_shape, in.dtype(), nullptr, {});

  // Fill with zeros
  copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);

  // Pick input slice from padded
  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
      conv_params.pad[1] * in_padded.strides()[2];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);

  // Make strided view
  std::vector<int> strided_shape = {
      conv_params.N,
      conv_params.oS[0],
      conv_params.oS[1],
      conv_params.wS[0],
      conv_params.wS[1],
      conv_params.C};

  std::vector<size_t> strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * conv_params.str[0],
      in_padded.strides()[2] * conv_params.str[1],
      in_padded.strides()[1],
      in_padded.strides()[2],
      in_padded.strides()[3]};
  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  std::vector<int> strided_reshape = {
      conv_params.N * conv_params.oS[0] * conv_params.oS[1],
      conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_gpu(in_strided_view, in_strided, CopyType::General, s);

  // Perform gemm
  std::vector<array> copies = {in_padded, in_strided};
  mlx_matmul(
      s,
      d,
      /*a = */ in_strided,
      /*b = */ wt,
      /*c = */ out,
      /*M = */ strided_reshape[0],
      /*N = */ conv_params.O,
      /*K = */ strided_reshape[1],
      /*batch_size_out = */ 1,
      /*a_cols = */ strided_reshape[1],
      /*b_cols = */ strided_reshape[1],
      /*a_transposed = */ false,
      /*b_transposed = */ true,
      /*copies = */ copies);
}

void winograd_conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params,
    std::vector<array>& copies_w) {
  std::vector<int> padded_shape = {
      conv_params.N,
      conv_params.iS[0] + 2 * conv_params.pad[0],
      conv_params.iS[1] + 2 * conv_params.pad[1],
      conv_params.C};

  padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
  padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;

  array in_padded(padded_shape, in.dtype(), nullptr, {});

  // Fill with zeros
  array zero_arr = array(0, in.dtype());
  copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);

  // Pick input slice from padded
  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
      conv_params.pad[1] * in_padded.strides()[2];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);

  copies_w.push_back(in_padded_slice);
  copies_w.push_back(in_padded);
  copies_w.push_back(zero_arr);

  MLXConvParams<2> conv_params_updated{
      /* const int N = */ in_padded.shape(0),
      /* const int C = */ in_padded.shape(3),
      /* const int O = */ wt.shape(0),
      /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
      /* const int str[NDIM] = */ {1, 1},
      /* const int pad[NDIM] = */ {0, 0},
      /* const int dil[NDIM] = */ {1, 1},
      /* const size_t in_strides[NDIM + 2] = */
      {in_padded.strides()[0],
       in_padded.strides()[1],
       in_padded.strides()[2],
       in_padded.strides()[3]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
  };

  int O_c = conv_params.O;
  int C_c = conv_params.C;

  int N_tiles_n = conv_params.N;
  int N_tiles_h = (conv_params.oS[0] + 5) / 6;
  int N_tiles_w = (conv_params.oS[1] + 5) / 6;
  int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;

  // Do filter transform
  std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
  array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
  filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
  copies_w.push_back(filt_wg);
  {
    int bc = 32;
    int bo = 4;
    std::ostringstream kname;
    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
          << bc;
    auto compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    set_array_buffer(compute_encoder, wt, 0);
    set_array_buffer(compute_encoder, filt_wg, 1);

    compute_encoder->setBytes(&C_c, sizeof(int), 2);
    compute_encoder->setBytes(&O_c, sizeof(int), 3);

    MTL::Size group_dims = MTL::Size(32, bo, 1);
    MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }

  // Do input transform
  std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
  array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
  inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
  copies_w.push_back(inp_wg);
  {
    int bc = 32;
    int wm = 2;
    int wn = 2;
    std::ostringstream kname;
    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
          << bc;
    auto compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    set_array_buffer(compute_encoder, in_padded, 0);
    set_array_buffer(compute_encoder, inp_wg, 1);

    compute_encoder->setBytes(
        &conv_params_updated, sizeof(MLXConvParams<2>), 2);

    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }

  // Do batched gemm
  std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
  array out_wg(out_wg_shape, in.dtype(), nullptr, {});
  out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
  copies_w.push_back(out_wg);
  {
    std::vector<array> empty_copies;
    mlx_matmul(
        s,
        d,
        /*a = */ inp_wg,
        /*b = */ filt_wg,
        /*c = */ out_wg,
        /*M = */ N_tiles,
        /*N = */ conv_params.O,
        /*K = */ conv_params.C,
        /*batch_size_out = */ 8 * 8,
        /*a_cols = */ conv_params.C,
        /*b_cols = */ conv_params.O,
        /*a_transposed = */ false,
        /*b_transposed = */ false,
        /*copies = */ empty_copies);
  }

  // Do output transform
  {
    int bc = 32;
    int wm = 2;
    int wn = 2;
    std::ostringstream kname;
    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
          << bc;
    auto compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    set_array_buffer(compute_encoder, out_wg, 0);
    set_array_buffer(compute_encoder, out, 1);

    compute_encoder->setBytes(
        &conv_params_updated, sizeof(MLXConvParams<2>), 2);

    MTL::Size group_dims = MTL::Size(32, wn, wm);
    MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }
}
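// Added tiling note: this is Winograd F(6x6, 3x3), so each 8x8 tile in the
// transformed domain yields a 6x6 output tile. N_tiles_h/w round the output
// up via (oS + 5) / 6, and the padded input is grown to a multiple of 6
// plus the 2-pixel halo, 6 * ceil((p - 2) / 6) + 2, so tiles always cover
// the image exactly.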
void conv_2D_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    std::vector<array>& copies) {
  // Make conv params
  MLXConvParams<2> conv_params{
      /* const int N = */ in.shape(0),
      /* const int C = */ in.shape(3),
      /* const int O = */ wt.shape(0),
      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
      /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
      /* const int pad[NDIM] = */ {padding[0], padding[1]},
      /* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
  };

  // Direct to winograd conv
  if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
      conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
      conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
      conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
      conv_params.dil[1] == 1) {
    winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
  }

  // Direct to implicit gemm conv
  else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
    implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to explicit gemm conv
  else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
    explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to fallback conv
  else {
    slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }
}
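// Added dispatch summary: 3x3, stride-1, dilation-1 convolutions with
// channel counts that are multiples of 32 (and at least 64) take the
// Winograd path; other 32-aligned channel shapes use implicit GEMM;
// undilated shapes fall back to explicit GEMM; everything else runs the
// naive kernel.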
} // namespace

void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Ensure contiguity
  std::vector<array> copies;
  auto in = inputs[0];
  auto wt = inputs[1];
  if (!in.flags().row_contiguous) {
    array arr_copy(in.shape(), in.dtype(), nullptr, {});
    copy_gpu(in, arr_copy, CopyType::General, s);
    copies.push_back(arr_copy);
    in = arr_copy;
  }
  if (!wt.flags().row_contiguous) {
    array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
    copy_gpu(wt, arr_copy, CopyType::General, s);
    copies.push_back(arr_copy);
    wt = arr_copy;
  }

  // 2D conv
  if (out.ndim() == 4) {
    conv_2D_gpu(
        s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
  }
  // 1D conv
  else if (out.ndim() == 3) {
    conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
  }
  // Throw error
  else {
    throw std::invalid_argument(
        "[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
  }

  // Clear copies
  if (copies.size() > 0) {
    auto command_buffer = d.get_command_buffer(s.index);
    command_buffer->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  }
}

} // namespace mlx::core
16
mlx/backend/metal/copy.h
Normal file
@ -0,0 +1,16 @@
#pragma once

#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"

namespace mlx::core {

void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
    const array& src,
    array& out,
    CopyType ctype,
    const Stream& s);

} // namespace mlx::core
30
mlx/backend/metal/kernels/arange.metal
Normal file
@ -0,0 +1,30 @@
#include "mlx/backend/metal/kernels/bf16.h"

template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}

#define instantiate_arange(tname, type) \
  template [[host_name("arange" #tname)]] \
  [[kernel]] void arange<type>( \
      constant const type& start, \
      constant const type& step, \
      device type* out, \
      uint index [[thread_position_in_grid]]);

instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)
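// Added note on the naming convention: each instantiation is exported under
// host_name "arange" + tname, e.g. "arangefloat32", which is the string the
// host side passes to get_kernel() to fetch the compute pipeline.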
208
mlx/backend/metal/kernels/arg_reduce.metal
Normal file
@ -0,0 +1,208 @@
#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename U>
struct IndexValPair {
  uint32_t index;
  U val;

  IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};

template <typename U>
struct ArgMin {
  static constexpr constant U init = Limits<U>::max;

  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  IndexValPair<U> reduce_many(
      IndexValPair<U> best,
      thread U* vals,
      uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename U>
struct ArgMax {
  static constexpr constant U init = Limits<U>::min;

  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  IndexValPair<U> reduce_many(
      IndexValPair<U> best,
      thread U* vals,
      uint32_t offset) {
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

bool simd_shuffle_down(bool data, uint16_t delta) {
  return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}

int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}

template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
  return IndexValPair<U>(
      simd_shuffle_down(data.index, delta),
      simd_shuffle_down(data.val, delta));
}
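// Added note on the overloads above: the built-in simd_shuffle_down is not
// defined for 64-bit integers or bool, so 64-bit values are shuffled as a
// uint2 pair through as_type<>, bool goes through uint32_t, and
// IndexValPair is shuffled member-wise.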
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
    const device T *in [[buffer(0)]],
    device uint32_t *out [[buffer(1)]],
    const device int *shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    const device size_t& ndim [[buffer(5)]],
    const device size_t& axis_stride [[buffer(6)]],
    const device size_t& axis_size [[buffer(7)]],
    threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

  // Shapes and strides *do not* contain the reduction axis. The reduction size
  // and stride are provided in axis_stride and axis_size.
  //
  // Note: in shape == out shape with this convention.
  //
  // The sketch of the kernel is as follows.
  // 1. Launch prod(shape) * thread_group_size threads.
  // 2. Loop ceildiv(axis_size, N_READS * lsize) times
  // 3. Read input values
  // 4. Reduce among them and go to 3
  // 5. Reduce within each simd group
  // 6. Write to the threadgroup-local memory
  // 7. Reduce across the thread group
  // 8. Write the output without the need for atomics
  Op op;

  // Compute the input/output index. There is one beginning and one output for
  // the whole threadgroup.
  auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
  auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);

  IndexValPair<T> best(0, Op::init);

  // Loop over the reduction axis in lsize*N_READS buckets
  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
    // Read the current value
    uint32_t current_index = r * lsize * N_READS + lid * N_READS;
    uint32_t offset = current_index;
    const device T* current_in = in + in_idx + current_index * axis_stride;
    T vals[N_READS];
    for (int i = 0; i < N_READS; i++) {
      vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
      current_index++;
      current_in += axis_stride;
    }
    best = op.template reduce_many<N_READS>(best, vals, offset);
  }
  // At this point we have reduced the axis into thread group best values so we
  // need to reduce across the thread group.

  // First per-simd reduction.
  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Write to the threadgroup memory
  if (simd_lane_id == 0) {
    local_data[simd_group_id] = best;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id != 0) {
    return;
  }

  // Read the appropriate value from local data and perform one simd reduction
  uint simd_groups = ceildiv(lsize, simd_size);
  if (simd_lane_id < simd_groups) {
    best = local_data[simd_lane_id];
  }
  for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Finally write the output
  if (lid == 0) {
    out[out_idx] = best.index;
  }
}

#define instantiate_arg_reduce_helper(name, itype, op) \
  template [[host_name(name)]] \
  [[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
      const device itype *in [[buffer(0)]], \
      device uint32_t * out [[buffer(1)]], \
      const device int *shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
|
||||
const device size_t *out_strides [[buffer(4)]], \
|
||||
const device size_t& ndim [[buffer(5)]], \
|
||||
const device size_t& axis_stride [[buffer(6)]], \
|
||||
const device size_t& axis_size [[buffer(7)]], \
|
||||
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint lsize [[threads_per_threadgroup]], \
|
||||
uint simd_size [[threads_per_simdgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_arg_reduce(name, itype) \
|
||||
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
|
||||
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
|
||||
|
||||
instantiate_arg_reduce(bool_, bool)
|
||||
instantiate_arg_reduce(uint8, uint8_t)
|
||||
instantiate_arg_reduce(uint16, uint16_t)
|
||||
instantiate_arg_reduce(uint32, uint32_t)
|
||||
instantiate_arg_reduce(uint64, uint64_t)
|
||||
instantiate_arg_reduce(int8, int8_t)
|
||||
instantiate_arg_reduce(int16, int16_t)
|
||||
instantiate_arg_reduce(int32, int32_t)
|
||||
instantiate_arg_reduce(int64, int64_t)
|
||||
instantiate_arg_reduce(float16, half)
|
||||
instantiate_arg_reduce(float32, float)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t)
|
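To make the tie-breaking rule of the reduction concrete (prefer the better value, and on equal values prefer the smaller index, i.e. the first occurrence), here is the same combine step as plain, runnable C++; `IndexVal` and `argmax_reduce` are illustrative names, not part of the kernel:

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Plain C++ mirror of the ArgMax::reduce tie-breaking rule used by the
// kernel: take the larger value, and on ties take the smaller index.
struct IndexVal {
  uint32_t index;
  float val;
};

IndexVal argmax_reduce(IndexVal best, IndexVal current) {
  if (best.val < current.val ||
      (best.val == current.val && best.index > current.index)) {
    return current;
  }
  return best;
}

int main() {
  std::vector<float> x = {1.0f, 3.0f, 3.0f, 2.0f};
  IndexVal best{0, -std::numeric_limits<float>::infinity()};
  for (uint32_t i = 0; i < x.size(); ++i) {
    best = argmax_reduce(best, {i, x[i]});
  }
  assert(best.index == 1); // first occurrence of the maximum wins
  return 0;
}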
392
mlx/backend/metal/kernels/bf16_math.h
Normal file
@ -0,0 +1,392 @@
#pragma once

#include "mlx/backend/metal/kernels/bf16.h"

///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////

/*

Following the Metal Shading Language Specification (Metal 3.1)

"bfloat is an extended floating point type that only allows implicit conversion
to a type of greater floating point rank. While bfloat can be implicitly
converted to float, it cannot be implicitly converted to half, and neither
float nor half can be implicitly converted to bfloat."

Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat, and calling one with an argument of type bfloat will result in that
argument getting implicitly converted to float. The call then returns a float,
which cannot be implicitly converted back into a bfloat.

This leads to situations where
    bfloat a = 5.0bf;
    bfloat b = metal::abs(a); // this will throw an error since abs returns float
    bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine

For the moment, we add overloaded instantiations of the math functions that
handle the casting automatically.

*/

#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  \
  METAL_FUNC otype abs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acos(itype x) { \
    return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acosh(itype x) { \
    return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asin(itype x) { \
    return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asinh(itype x) { \
    return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atan(itype y_over_x) { \
    return static_cast<otype>( \
        __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
  } \
  METAL_FUNC otype atan2(itype y, itype x) { \
    return static_cast<otype>( \
        __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atanh(itype x) { \
    return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype ceil(itype x) { \
    return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cos(itype x) { \
    return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cosh(itype x) { \
    return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cospi(itype x) { \
    return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype divide(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype exp(itype x) { \
    return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp10(itype x) { \
    return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp2(itype x) { \
    return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fabs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast<ctype>(x - y); \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  METAL_FUNC otype floor(itype x) { \
    return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fma( \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  } \
  METAL_FUNC otype fmax(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmin(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmod(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fract(itype x) { \
    return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  } \
  METAL_FUNC otype log(itype x) { \
    return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log10(itype x) { \
    return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log2(itype x) { \
    return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype max(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype min(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
  } \
  METAL_FUNC otype pow(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype powr(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype rint(itype x) { \
    return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype round(itype x) { \
    return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype rsqrt(itype x) { \
    return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sin(itype x) { \
    return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinh(itype x) { \
    return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinpi(itype x) { \
    return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sqrt(itype x) { \
    return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tan(itype x) { \
    return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanh(itype x) { \
    return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanpi(itype x) { \
    return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype trunc(itype x) { \
    return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
  }

namespace metal {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_MAYBE_FAST_MATH__);

namespace fast {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_FAST_MATH__);

} // namespace fast

namespace precise {

instantiate_metal_math_funcs(
    bfloat16_t,
    bfloat16_t,
    float,
    __METAL_PRECISE_MATH__);

} // namespace precise

} // namespace metal

///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////

#define instantiate_metal_simd_comm_funcs( \
    itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
  }

#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  \
  METAL_FUNC otype simd_max(itype data) { \
    return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_min(itype data) { \
    return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_product(itype data) { \
    return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_sum(itype data) { \
    return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_xor(itype data) { \
    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
  }

#if defined(__HAVE_BFLOAT__)

#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif

namespace metal {

instantiate_metal_simd_comm_funcs(
    bfloat16_t,
    bfloat16_t,
    uint16_t,
    bfloat16_to_uint16,
    uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);

} // namespace metal
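The fallback `bfloat16_to_uint16` / `uint16_to_bfloat16` macros above treat a software bfloat16 as the high 16 bits of a float32. A host-side C++ round-trip sketch of that representation (an illustration only; `_MLX_BFloat16` may round rather than truncate on conversion):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Sketch of the software bfloat16 representation implied by the fallback
// macros: the stored bits are the high 16 bits of an IEEE float32.
uint16_t float_to_bfloat_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16); // truncation; MLX may round instead
}

float bfloat_bits_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  float x = 3.140625f; // exactly representable in bfloat16
  uint16_t bits = float_to_bfloat_bits(x);
  std::printf("%f -> 0x%04x -> %f\n", x, bits, bfloat_bits_to_float(bits));
  return 0;
}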
553
mlx/backend/metal/kernels/conv.metal
Normal file
@ -0,0 +1,553 @@
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/bf16.h"

#include "mlx/backend/metal/kernels/gemm/conv.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Slow and naive kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          const int BM, /* Threadgroup rows (in threads) */
          const int BN, /* Threadgroup cols (in threads) */
          const int TM, /* Thread rows (in elements) */
          const int TN, /* Thread cols (in elements) */
          const int BC = 16>
[[kernel]] void naive_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  (void)simd_gid;
  (void)simd_lid;

  out += tid.z * params.out_strides[0];
  in += tid.z * params.in_strides[0];

  int out_o = tid.y * BN * TN + lid.y * TN;
  int out_hw = tid.x * BM * TM + lid.x * TM;

  int out_h[TM];
  int out_w[TM]; // indexed by m < TM below, so sized TM

  for(int m = 0; m < TM; ++m) {
    int mm = (out_hw + m);
    out_h[m] = mm / params.oS[1];
    out_w[m] = mm % params.oS[1];
  }

  T in_local[TM];
  T wt_local[TN];
  T out_local[TM * TN] = {T(0)};

  for(int h = 0; h < params.wS[0]; ++h) {
    for(int w = 0; w < params.wS[1]; ++w) {
      for(int c = 0; c < params.C; ++c) {

        // Load input
        for(int m = 0; m < TM; m++) {
          int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
          int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];

          bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
          in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
        }

        // Load weight
        for (int n = 0; n < TN; ++n) {
          int o = out_o + n;
          wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
                                          h * params.wt_strides[1] +
                                          w * params.wt_strides[2] + c] : T(0);
        }

        // Accumulate
        for(int m = 0; m < TM; ++m) {
          for(int n = 0; n < TN; ++n) {
            out_local[m * TN + n] += in_local[m] * wt_local[n];
          }
        }

      }
    }
  }

  for(int m = 0; m < TM; ++m) {
    for(int n = 0; n < TN; ++n) {
      if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
        out[out_h[m] * params.out_strides[1] +
            out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
    }
  }

}

// Instantiations

#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
  [[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_naive_conv_2d_blocks(name, itype) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)

instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;

  threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];

  gemm_kernel::run(
    in, wt, out,
    params, tgp_memory,
    tid, lid, simd_gid, simd_lid
  );

}

#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
  instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

template <int M, int R, int S>
struct WinogradTransforms {
};

template <>
struct WinogradTransforms<6, 3, 8> {
  MLX_MTL_CONST int OUT_TILE_SIZE = 6;
  MLX_MTL_CONST int FILTER_SIZE = 3;
  MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
  MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
  MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      { 1.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f},
      { 0.00f,  1.00f, -1.00f,  0.50f, -0.50f,  2.00f, -2.00f, -1.00f},
      {-5.25f,  1.00f,  1.00f,  0.25f,  0.25f,  4.00f,  4.00f,  0.00f},
      { 0.00f, -4.25f,  4.25f, -2.50f,  2.50f, -2.50f,  2.50f,  5.25f},
      { 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f,  0.00f},
      { 0.00f,  1.00f, -1.00f,  2.00f, -2.00f,  0.50f, -0.50f, -5.25f},
      {-1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  0.00f},
      { 0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  1.00f},
  };

  MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      { 1.00f,  0.00f,  0.00f,  0.00f,   0.00f,    0.00f},
      { 1.00f,  1.00f,  1.00f,  1.00f,   1.00f,    1.00f},
      { 1.00f, -1.00f,  1.00f, -1.00f,   1.00f,   -1.00f},
      { 1.00f,  2.00f,  4.00f,  8.00f,  16.00f,   32.00f},
      { 1.00f, -2.00f,  4.00f, -8.00f,  16.00f,  -32.00f},
      { 1.00f,  0.50f,  0.25f,  0.125f,  0.0625f,  0.03125f},
      { 1.00f, -0.50f,  0.25f, -0.125f,  0.0625f, -0.03125f},
      { 0.00f,  0.00f,  0.00f,  0.00f,   0.00f,    1.00f},
  };

  MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
      {      1.00,       0.00,      0.00},
      { -2.0/9.00,  -2.0/9.00, -2.0/9.00},
      { -2.0/9.00,   2.0/9.00, -2.0/9.00},
      {  1.0/90.0,   1.0/45.0,  2.0/45.0},
      {  1.0/90.0,  -1.0/45.0,  2.0/45.0},
      { 32.0/45.0,  16.0/45.0,  8.0/45.0},
      { 32.0/45.0, -16.0/45.0,  8.0/45.0},
      {      0.00,       0.00,      1.00},
  };
};

constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];

template <typename T,
          int BC = 32,
          int BO = 4,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
    const device T* wt_in [[buffer(0)]],
    device T* wt_out [[buffer(1)]],
    const constant int& C [[buffer(2)]],
    const constant int& O [[buffer(3)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {

  using WGT = WinogradTransforms<M, R, 8>;

  // Get lane position in simdgroup
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize G matrix
  simdgroup_matrix<T, 8, 8> G;
  G.thread_elements()[0] = WGT::wt_transform[sm][sn];
  G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];

  // Initialize Gt matrix
  simdgroup_matrix<T, 8, 8> Gt;
  Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
  Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];

  // Move to the correct output filter
  size_t ko = BO * tid + simd_group_id;
  wt_in += ko * R * R * C;

  // wt_out is stored transposed (A x A x C x O)
  short ohw_0 = sm * 8 + sn;
  short ohw_1 = sm * 8 + sn + 1;
  device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
  device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;

  // Prepare shared memory
  threadgroup T Ws[BO][R][R][BC];

  // Loop over C
  for(int bc = 0; bc < C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read into shared memory
    for(int kh = 0; kh < R; ++kh) {
      for(int kw = 0; kw < R; ++kw) {
        for(int kc = simd_lane_id; kc < BC; kc += 32) {
          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for(int c = 0; c < BC; ++c) {
      simdgroup_matrix<T, 8, 8> g;
      g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
      g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

      simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
      wt_out_0[c * O] = g_out.thread_elements()[0];
      wt_out_1[c * O] = g_out.thread_elements()[1];
    }

    wt_in += BC;
    wt_out_0 += BC * O;
    wt_out_1 += BC * O;
  }

}

#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
  template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]] \
  [[kernel]] void winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
      const constant int& O [[buffer(3)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

template <typename T,
          int BC,
          int WM,
          int WN,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
    const device T* inp_in [[buffer(0)]],
    device T* inp_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {

  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
  constexpr int A = WGT::IN_TILE_SIZE;
  constexpr int N_SIMD_GROUPS = WM * WN;

  // Get lane position in simdgroup
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize B matrix
  simdgroup_matrix<T, 8, 8> B;
  B.thread_elements()[0] = WGT::in_transform[sm][sn];
  B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];

  // Initialize Bt matrix
  simdgroup_matrix<T, 8, 8> Bt;
  Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
  Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];

  // Resolve input tile
  constexpr int TH = (A / WM);
  constexpr int TW = (A / WN);
  int kh = TH * (simd_group_id / WN);
  int kw = TW * (simd_group_id % WN);
  int bh = M * tid.y + kh;
  int bw = M * tid.x + kw;

  // Move to the correct input tile
  inp_in += tid.z * params.in_strides[0]
          + bh * params.in_strides[1]
          + bw * params.in_strides[2];

  // Precompute strides
  int jump_in[TH][TW];

  for(int h = 0; h < TH; h++) {
    for(int w = 0; w < TW; w++) {
      jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
    }
  }

  // inp_out is stored interleaved (A x A x tiles x C)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t ohw_0 = sm * 8 + sn;
  size_t ohw_1 = sm * 8 + sn + 1;
  device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
  device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;

  // Prepare shared memory
  threadgroup T Is[A][A][BC];

  // Loop over C
  for(int bc = 0; bc < params.C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read into shared memory
    for(int h = 0; h < TH; h++) {
      for(int w = 0; w < TW; w++) {
        const device T* in_ptr = inp_in + jump_in[h][w];
        for(int c = simd_lane_id; c < BC; c += 32) {
          Is[kh + h][kw + w][c] = in_ptr[c];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> I;
      I.thread_elements()[0] = Is[sm][sn][c];
      I.thread_elements()[1] = Is[sm][sn + 1][c];

      simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
      inp_out_0[c] = I_out.thread_elements()[0];
      inp_out_1[c] = I_out.thread_elements()[1];
    }

    inp_in += BC;
    inp_out_0 += BC;
    inp_out_1 += BC;
  }

}

#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
  template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]] \
  [[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

template <typename T,
          int BO,
          int WM,
          int WN,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
    const device T* out_in [[buffer(0)]],
    device T* out_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {

  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
  constexpr int N_SIMD_GROUPS = WM * WN;

  // Get lane position in simdgroup
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize the output transform matrix A (stored in B)
  simdgroup_matrix<T, 8, 8> B;
  B.thread_elements()[0] = WGT::out_transform[sm][sn];
  B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];

  // Initialize the transposed matrix At (stored in Bt)
  simdgroup_matrix<T, 8, 8> Bt;
  Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
  Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];

  // out_in comes in shape (A x A x tiles x O).
  // We do the transform and then write out to out_out in shape (N, H, W, O).

  // Resolve output tile
  constexpr int TH = (M / WM);
  constexpr int TW = (M / WN);
  int kh = TH * (simd_group_id / WN);
  int kw = TW * (simd_group_id % WN);
  int bh = M * tid.y + kh;
  int bw = M * tid.x + kw;

  // Move to the correct output tile
  out_out += tid.z * params.out_strides[0]
           + bh * params.out_strides[1]
           + bw * params.out_strides[2];

  // Precompute strides
  int jump_in[TH][TW];

  for(int h = 0; h < TH; h++) {
    for(int w = 0; w < TW; w++) {
      bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
      jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
    }
  }

  // out_in is stored interleaved (A x A x tiles x O)
  size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t ohw_0 = sm * 8 + sn;
  size_t ohw_1 = sm * 8 + sn + 1;
  const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
  const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;

  // Prepare shared memory
  threadgroup T Os[M][M][BO];

  // Loop over O
  for(int bo = 0; bo < params.O; bo += BO) {

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Do transform and store the result
    for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> O_mat;
      O_mat.thread_elements()[0] = out_in_0[c];
      O_mat.thread_elements()[1] = out_in_1[c];

      simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
      if((sm < M) && (sn < M)) {
        Os[sm][sn][c] = O_out.thread_elements()[0];
      }
      if((sm < M) && ((sn + 1) < M)) {
        Os[sm][sn + 1][c] = O_out.thread_elements()[1];
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Read out from shared memory
    for(int h = 0; h < TH; h++) {
      for(int w = 0; w < TW; w++) {
        if(jump_in[h][w] >= 0) {
          device T* out_ptr = out_out + jump_in[h][w];
          for(int c = simd_lane_id; c < BO; c += 32) {
            out_ptr[c] = Os[kh + h][kw + w][c];
          }
        }
      }
    }

    out_out += BO;
    out_in_0 += BO;
    out_in_1 += BO;
  }

}

#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
  template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]] \
  [[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
      const device itype* out_in [[buffer(0)]], \
      device itype* out_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);

#define instantiate_winograd_conv_2d(name, itype) \
  instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
  instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
  instantiate_winograd_conv_2d_output_transform(name, itype, 32)

instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
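The padding/stride/dilation index math in `naive_conv_2d` (`i = out_h * str - pad + h * dil`, zeros outside the input) is easy to check against a scalar reference. A single-channel, single-filter C++ sketch (`conv2d_ref` is a helper name introduced here, not part of the kernels):

#include <cassert>
#include <vector>

// Scalar reference for the padded/strided/dilated index math used by
// naive_conv_2d, reduced to one channel and one output filter.
std::vector<float> conv2d_ref(
    const std::vector<float>& in, int iH, int iW,
    const std::vector<float>& wt, int wH, int wW,
    int stride, int pad, int dil, int oH, int oW) {
  std::vector<float> out(oH * oW, 0.0f);
  for (int oh = 0; oh < oH; ++oh) {
    for (int ow = 0; ow < oW; ++ow) {
      float acc = 0.0f;
      for (int kh = 0; kh < wH; ++kh) {
        for (int kw = 0; kw < wW; ++kw) {
          int i = oh * stride - pad + kh * dil;
          int j = ow * stride - pad + kw * dil;
          bool valid = i >= 0 && i < iH && j >= 0 && j < iW;
          acc += (valid ? in[i * iW + j] : 0.0f) * wt[kh * wW + kw];
        }
      }
      out[oh * oW + ow] = acc;
    }
  }
  return out;
}

int main() {
  std::vector<float> in(9, 1.0f);  // 3x3 input of ones
  std::vector<float> wt(9, 1.0f);  // 3x3 filter of ones
  auto out = conv2d_ref(in, 3, 3, wt, 3, 3, /*stride*/1, /*pad*/1, /*dil*/1, 3, 3);
  assert(out[4] == 9.0f); // the center output sees the full window
  return 0;
}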
269
mlx/backend/metal/kernels/copy.metal
Normal file
@ -0,0 +1,269 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src,
    device U* dst,
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}

template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src,
    device U* dst,
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src,
    device U* dst,
    constant const size_t& src_stride,
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src,
    device U* dst,
    constant const int src_shape[DIM],
    constant const size_t src_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src,
    device U* dst,
    constant const int* src_shape,
    constant const size_t* src_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src,
    device U* dst,
    constant const size_t& src_stride,
    constant const size_t& dst_stride,
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  auto dst_idx = elem_to_loc_1(index, dst_stride);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[2],
    constant const size_t dst_strides[2],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  auto dst_idx = elem_to_loc_2(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[3],
    constant const size_t dst_strides[3],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  auto dst_idx = elem_to_loc_3(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src,
    device U* dst,
    constant const int src_shape[DIM],
    constant const size_t src_strides[DIM],
    constant const size_t dst_strides[DIM],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src,
    device U* dst,
    constant const int* src_shape,
    constant const size_t* src_strides,
    constant const size_t* dst_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

#define instantiate_copy(name, itype, otype, ctype) \
  template [[host_name(name)]] \
  [[kernel]] void copy_##ctype<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      uint index [[thread_position_in_grid]]);

#define instantiate_copy_g_dim(name, itype, otype, dims) \
  template [[host_name(name "_" #dims)]] \
  [[kernel]] void copy_g_nd<itype, otype, dims>( \
      device const itype* src, \
      device otype* dst, \
      constant const int src_shape[dims], \
      constant const size_t src_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_" #dims)]] \
  [[kernel]] void copy_gg_nd<itype, otype, dims>( \
      device const itype* src, \
      device otype* dst, \
      constant const int src_shape[dims], \
      constant const size_t src_strides[dims], \
      constant const size_t dst_strides[dims], \
      uint3 index [[thread_position_in_grid]]);


#define instantiate_copy_g_nd(name, itype, otype) \
  template [[host_name(name "_1")]] \
  [[kernel]] void copy_g_nd1<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t& src_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] \
  [[kernel]] void copy_g_nd2<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_1")]] \
  [[kernel]] void copy_gg_nd1<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t& src_stride, \
      constant const size_t& dst_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_2")]] \
  [[kernel]] void copy_gg_nd2<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[2], \
      constant const size_t dst_strides[2], \
      uint2 index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_3")]] \
  [[kernel]] void copy_gg_nd3<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[3], \
      constant const size_t dst_strides[3], \
      uint3 index [[thread_position_in_grid]]); \
  instantiate_copy_g_dim(name, itype, otype, 4) \
  instantiate_copy_g_dim(name, itype, otype, 5)


#define instantiate_copy_g(name, itype, otype) \
  template [[host_name(name)]] \
  [[kernel]] void copy_g<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const int* src_shape, \
      constant const size_t* src_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name)]] \
  [[kernel]] void copy_gg<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const int* src_shape, \
      constant const size_t* src_strides, \
      constant const size_t* dst_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]]);

#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy("scopy" #tname, itype, otype, s) \
  instantiate_copy("vcopy" #tname, itype, otype, v) \
  instantiate_copy_g("gcopy" #tname, itype, otype) \
  instantiate_copy_g_nd("gcopy" #tname, itype, otype)

#define instantiate_copy_itype(itname, itype) \
  instantiate_copy_all(itname ##bool_, itype, bool) \
  instantiate_copy_all(itname ##uint8, itype, uint8_t) \
  instantiate_copy_all(itname ##uint16, itype, uint16_t) \
  instantiate_copy_all(itname ##uint32, itype, uint32_t) \
  instantiate_copy_all(itname ##uint64, itype, uint64_t) \
  instantiate_copy_all(itname ##int8, itype, int8_t) \
  instantiate_copy_all(itname ##int16, itype, int16_t) \
  instantiate_copy_all(itname ##int32, itype, int32_t) \
  instantiate_copy_all(itname ##int64, itype, int64_t) \
  instantiate_copy_all(itname ##float16, itype, half) \
  instantiate_copy_all(itname ##float32, itype, float) \
  instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
  instantiate_copy_all(itname ##complex64, itype, complex64_t)

instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)
instantiate_copy_itype(uint16, uint16_t)
instantiate_copy_itype(uint32, uint32_t)
instantiate_copy_itype(uint64, uint64_t)
instantiate_copy_itype(int8, int8_t)
instantiate_copy_itype(int16, int16_t)
instantiate_copy_itype(int32, int32_t)
instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
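All of the general copy kernels above route through the `elem_to_loc*` helpers from utils.h: decompose a flat element index against a shape, then recompose it against arbitrary strides. A host-side C++ sketch of that mapping (assuming innermost-dimension-first decomposition, which the nd1/nd2/nd3 specializations suggest; `elem_to_loc` here is a standalone illustration, not the Metal helper):

#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the elem_to_loc helper family: map a flat (contiguous) element
// index to a memory offset under arbitrary strides. The flat index is
// decomposed against the shape from the innermost dimension outward.
size_t elem_to_loc(size_t elem, const std::vector<int>& shape,
                   const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed: logical element (r, c) lives at c * 2 + r.
  std::vector<int> shape = {2, 3};
  std::vector<size_t> strides = {1, 2}; // column-major strides
  assert(elem_to_loc(1, shape, strides) == 2); // element (0, 1) -> offset 2
  return 0;
}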
68
mlx/backend/metal/kernels/erf.h
Normal file
@ -0,0 +1,68 @@
#pragma once

#include <metal_math>

/*
 * Approximation to the error function.
 * Based on code from:
 * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
 */
float erf(float a) {
  float r, s, t, u;
  t = metal::abs(a);
  s = a * a;
  if (t > 0.927734375f) {
    // maximum error 0.99527 ulp
    r = metal::fma(
        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    u = metal::fma(
        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    r = metal::fma(r, s, u);
    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
    r = metal::fma(r, t, -t);
    // TODO: replace with expm1 when implemented
    r = 1.0f - metal::exp(r);
    r = metal::copysign(r, a);
  } else {
    // maximum error 0.98929 ulp
    r = -5.96761703e-4f; // -0x1.38e000p-11
    r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
    r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
    r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
    r = metal::fma(r, a, a);
  }
  return r;
}

float erfinv(float a) {
  auto t = metal::fma(a, 0.0f - a, 1.0f);
  t = metal::log(t);
  float p;
  if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
    p = 3.03697567e-10f; // 0x1.4deb44p-32
    p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
    p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
    p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
    p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
    p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
    p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
    p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
  } else { // maximum ulp error = 2.35002
    p = 5.43877832e-9f; // 0x1.75c000p-28
    p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
    p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
    p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
    p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
    p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * p;
}
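Both branches above are polynomial fits evaluated in fused-multiply-add form. A quick host-side check of the small-|a| branch against libm, reusing the coefficients verbatim (`erf_poly_small` is a name introduced here for the comparison):

#include <cmath>
#include <cstdio>

// Host-side copy of the |a| <= 0.927734375f branch above, for comparison
// against the C library erff/erf.
float erf_poly_small(float a) {
  float s = a * a;
  float r = -5.96761703e-4f;
  r = std::fma(r, s, 4.99119423e-3f);
  r = std::fma(r, s, -2.67681349e-2f);
  r = std::fma(r, s, 1.12819925e-1f);
  r = std::fma(r, s, -3.76125336e-1f);
  r = std::fma(r, s, 1.28379166e-1f);
  r = std::fma(r, a, a);
  return r;
}

int main() {
  for (int i = -3; i <= 3; ++i) {
    float a = 0.3f * i; // stays inside the small-|a| branch (|a| < 0.9277)
    std::printf("a=% .2f poly=% .7f libm=% .7f\n",
                a, erf_poly_small(a), std::erf(a));
  }
  return 0;
}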
91
mlx/backend/metal/kernels/gemm.metal
Normal file
@ -0,0 +1,91 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *C [[buffer(2)]],
    const constant int &M [[buffer(3)]],
    const constant int &N [[buffer(4)]],
    const constant int &K [[buffer(5)]],
    const constant int &batch_stride_a [[buffer(6)]],
    const constant int &batch_stride_b [[buffer(7)]],
    const constant int &batch_stride_c [[buffer(8)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {

  using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

  threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];

  gemm_kernel::run(
      A, B, C,
      M, N, K,
      batch_stride_a, batch_stride_b, batch_stride_c,
      tgp_memory,
      simd_lane_id, simd_group_id, tid, lid
  );
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device itype *C [[buffer(2)]], \
      const constant int &M [[buffer(3)]], \
      const constant int &N [[buffer(4)]], \
      const constant int &K [[buffer(5)]], \
      const constant int &batch_stride_a [[buffer(6)]], \
      const constant int &batch_stride_b [[buffer(7)]], \
      const constant int &batch_stride_c [[buffer(8)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

// TODO: Accumulation in different type
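The [[host_name(...)]] pattern above means the host looks kernels up by a fully mangled string. Below is a hedged C++ sketch of how such a name could be assembled (the helper is illustrative, not the dispatch code mlx actually ships):

#include <sstream>
#include <string>

// Builds names of the form produced by instantiate_gemm, e.g.
// "gemm_nn_float32_float32_bm32_bn32_bk16_wm2_wn2_MN_taligned_K_taligned".
std::string gemm_kernel_name(
    bool trans_a, bool trans_b, const std::string& iname,
    const std::string& oname, int bm, int bn, int bk, int wm, int wn,
    bool mn_aligned, bool k_aligned) {
  std::ostringstream os;
  os << "gemm_" << (trans_a ? 't' : 'n') << (trans_b ? 't' : 'n') << '_'
     << iname << '_' << oname << "_bm" << bm << "_bn" << bn << "_bk" << bk
     << "_wm" << wm << "_wn" << wn
     << "_MN_" << (mn_aligned ? "taligned" : "naligned")
     << "_K_" << (k_aligned ? "taligned" : "naligned");
  return os.str();
}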
99
mlx/backend/metal/kernels/random.metal
Normal file
@ -0,0 +1,99 @@
#include "mlx/backend/metal/kernels/utils.h"

static constexpr constant uint32_t rotations[2][4] = {
    {13, 15, 26, 6},
    {17, 29, 16, 24}
};

union rbits {
  uint2 val;
  uchar4 bytes[2];
};

rbits threefry2x32_hash(const thread uint2& key, uint2 count) {

  uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

  rbits v;
  v.val.x = count.x + ks[0];
  v.val.y = count.y + ks[1];

  for (int i = 0; i < 5; ++i) {
    for (auto r : rotations[i % 2]) {
      v.val.x += v.val.y;
      v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
      v.val.y ^= v.val.x;
    }
    v.val.x += ks[(i + 1) % 3];
    v.val.y += ks[(i + 2) % 3] + i + 1;
  }

  return v;
}

[[kernel]] void rbitsc(
    device const uint32_t* keys,
    device char* out,
    device const bool& odd,
    device const uint& bytes_per_key,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto kidx = 2 * index.x;
  auto key = uint2(keys[kidx], keys[kidx + 1]);
  auto half_size = grid_dim.y - odd;
  out += index.x * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);
  for (int i = 0; i < 4; ++i) {
    out[4 * count.x + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
      }
    }
  }
}

[[kernel]] void rbits(
    device const uint32_t* keys,
    device char* out,
    device const bool& odd,
    device const uint& bytes_per_key,
    device const int& ndim,
    device const int* key_shape,
    device const size_t* key_strides,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto kidx = 2 * index.x;
  auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim);
  auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
  auto key = uint2(keys[k1_elem], keys[k2_elem]);
  auto half_size = grid_dim.y - odd;
  out += index.x * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);
  for (int i = 0; i < 4; ++i) {
    out[4 * count.x + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
      }
    }
  }
}
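The hash above is Threefry-2x32 used as a counter-based RNG: five groups of four add-rotate-xor rounds, with a key-schedule injection after each group. A plain C++ reference of the same rounds (not part of this file), handy for checking kernel output on the host:

#include <cstdint>
#include <utility>

std::pair<uint32_t, uint32_t> threefry2x32(
    std::pair<uint32_t, uint32_t> key, std::pair<uint32_t, uint32_t> count) {
  constexpr uint32_t rotations[2][4] = {{13, 15, 26, 6}, {17, 29, 16, 24}};
  // Key schedule: the third word xors the key with the Threefry constant
  // 0x1BD11BDA (the Metal code keeps this in a uint4 but only uses 3 lanes).
  uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
  uint32_t x = count.first + ks[0];
  uint32_t y = count.second + ks[1];
  for (int i = 0; i < 5; ++i) {
    for (uint32_t r : rotations[i % 2]) {
      x += y;
      y = (y << r) | (y >> (32 - r)); // rotate left by r
      y ^= x;
    }
    x += ks[(i + 1) % 3];
    y += ks[(i + 2) % 3] + static_cast<uint32_t>(i) + 1;
  }
  return {x, y};
}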
536
mlx/backend/metal/kernels/reduce.metal
Normal file
@ -0,0 +1,536 @@
#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

static constant uint8_t simd_size = 32;

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T *out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] \
  [[kernel]] void init_reduce<otype, op>( \
      device otype *out [[buffer(1)]], \
      uint tid [[thread_position_in_grid]]);


///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // NB: this kernel assumes threads_per_threadgroup is at most
  // 1024. This way with a simd_size of 32, we are guaranteed to
  // complete the reduction in two steps of simd-level reductions.

  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = Op::init;

  in += gid * N_READS;

  int r = 0;
  for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
    U vals[N_READS] = {op.init};

    for(int i = 0; i < N_READS; i++) {
      vals[i] = static_cast<U>(in[i]);
    }
    for(int i = 0; i < N_READS; i++) {
      total_val = op(vals[i], total_val);
    }

    in += grid_size * N_READS;
  }

  // Separate case for the last set as we close out the reduction
  size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
  if (curr_idx < in_size) {
    int max_reads = in_size - curr_idx;
    T vals[N_READS];

    for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
      idx = idx < max_reads ? idx : max_reads - 1;
      vals[i] = in[idx];
    }
    for(int i = 0; i < N_READS; i++) {
      U val = i < max_reads ? vals[i] : Op::init;
      total_val = op(static_cast<U>(val), total_val);
    }
  }

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Reduction within thread group
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Reduction across threadgroups
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] \
  [[kernel]] void all_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);


///////////////////////////////////////////////////////////////////////////////
// General reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device int *in_shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    const device size_t& ndim [[buffer(5)]],
    uint gid [[thread_position_in_grid]]) {
  Op op;
  auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
  auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
  op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}

template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device int *in_shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    uint gid [[thread_position_in_grid]]) {
  Op op;
  auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
  auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
  op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}

#define instantiate_general_reduce_helper(name, itype, otype, op) \
  template [[host_name("general_reduce_" #name)]] \
  [[kernel]] void general_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device int *in_shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
      const device size_t *out_strides [[buffer(4)]], \
      const device size_t& ndim [[buffer(5)]], \
      uint gid [[thread_position_in_grid]]);

#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
  template [[host_name("general_reduce_" #name "_dim_" #n)]] \
  [[kernel]] void general_reduce<itype, otype, op, n>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device int *in_shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
      const device size_t *out_strides [[buffer(4)]], \
      uint gid [[thread_position_in_grid]]);

#define instantiate_general_reduce(name, itype, otype, op) \
  instantiate_general_reduce_helper(name, itype, otype, op) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)


///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& reduction_size [[buffer(2)]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

  Op op;

  // Each threadgroup handles 1 reduction
  in += tid * reduction_size + lid * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;
  threadgroup U local_vals[simd_size];

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for(int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize * N_READS;
  }

  // Separate case for the last set as we close out the reduction
  size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
  if(reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;

    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      int idx = min(i, max_reads - 1);
      vals[i] = static_cast<U>(in[idx]);
    }
    for(int i = 0; i < N_READS; i++) {
      T val = i < max_reads ? vals[i] : Op::init;
      total_val = op(static_cast<U>(val), total_val);
    }
  }

  total_val = op.simd_reduce(total_val);

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if multiple simd groups
  if(reduction_size > simd_size) {
    total_val = lid < simd_per_group ? local_vals[lid] : op.init;
    total_val = op.simd_reduce(total_val);
  }
  // Update output
  if (lid == 0) {
    out[tid] = total_val;
  }
}

#define instantiate_row_reduce(name, itype, otype, op) \
  template [[host_name("row_reduce_" #name)]] \
  [[kernel]] void row_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const device size_t& reduction_size [[buffer(2)]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);


///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
    const device T *in,
    device mlx_atomic<U> *out,
    threadgroup U *local_data,
    uint in_idx,
    uint out_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {

  Op op;
  T local_vals[N_READS];

  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;

  for(uint r = 0; r < N_READS; r++) {
    uint offset = base_offset + r;
    offset = offset < reduction_size ? offset : reduction_size - 1;
    local_vals[r] = in[in_idx + offset * reduction_stride];
  }

  U total_val = Op::init;
  for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    total_val = op(static_cast<U>(total_val), local_vals[r]);
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if(lid.y == 0) {
    U val = op.init;

    for(uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }

    op.atomic_update(out, val, out_idx);
  }
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;

  if(out_idx < out_size) {
    _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        out,
        local_data,
        out_idx,
        out_idx,
        reduction_size,
        reduction_stride,
        tid,
        lid,
        lsize);
  }
}

#define instantiate_col_reduce(name, itype, otype, op) \
  template [[host_name("col_reduce_" #name)]] \
  [[kernel]] void col_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint2 tid [[threadgroup_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]]);

template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
[[kernel]] void contiguous_strided_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const device int* in_shape [[buffer(5)]],
    const device size_t* in_strides [[buffer(6)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {

  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);

  if(out_idx < out_size) {
    _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        out,
        local_data,
        in_idx,
        out_idx,
        reduction_size,
        reduction_stride,
        tid,
        lid,
        lsize);
  }
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const device int* in_shape [[buffer(5)]],
    const device size_t* in_strides [[buffer(6)]],
    const device size_t& in_dim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {

  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);

  if(out_idx < out_size) {
    _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        out,
        local_data,
        in_idx,
        out_idx,
        reduction_size,
        reduction_stride,
        tid,
        lid,
        lsize);
  }
}

#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
  template [[host_name("contiguous_strided_reduce_" #name)]] \
  [[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const device int* in_shape [[buffer(5)]], \
      const device size_t* in_strides [[buffer(6)]], \
      const device size_t& in_dim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint2 tid [[threadgroup_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]]);

#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
  template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
  [[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const device int* in_shape [[buffer(5)]], \
      const device size_t* in_strides [[buffer(6)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint2 tid [[threadgroup_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]]);

#define instantiate_contiguous_strided(name, itype, otype, op) \
  instantiate_contiguous_strided_helper(name, itype, otype, op) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)


///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_reduce(name, itype, otype, op) \
  instantiate_all_reduce(name, itype, otype, op) \
  instantiate_row_reduce(name, itype, otype, op) \
  instantiate_col_reduce(name, itype, otype, op) \
  instantiate_contiguous_strided(name, itype, otype, op) \
  instantiate_general_reduce(name, itype, otype, op)

#define instantiate_same_reduce(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce(name ##tname, type, type, op<type>)

#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
  instantiate_reduce(name ##tname, itype, otype, op)

#define instantiate_reduce_from_types(name, otype, op) \
  instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
  instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)

// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)

instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)

instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)

instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)

instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)

// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)

instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)

instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
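The column-reduce path above boils down to: output j accumulates reduction_size input elements spaced reduction_stride apart, starting at an index derived from j. A scalar C++ model of that access pattern (not part of this file; Sum stands in for the op):

#include <cstdio>
#include <vector>

int main() {
  // 3x4 row-major input, reducing over axis 0: reduction_size = 3 and
  // reduction_stride = 4 (the stride of the reduced axis).
  std::vector<float> in = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  size_t reduction_size = 3, reduction_stride = 4, out_size = 4;
  for (size_t j = 0; j < out_size; ++j) {
    float acc = 0.0f; // Sum's init
    for (size_t r = 0; r < reduction_size; ++r)
      acc += in[j + r * reduction_stride];
    std::printf("out[%zu] = %g\n", j, acc); // 12 15 18 21
  }
}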
284
mlx/backend/metal/kernels/unary.metal
Normal file
@ -0,0 +1,284 @@
#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/bf16.h"

struct Abs {
  template <typename T> T operator()(T x) { return metal::abs(x); };
  template <> uint8_t operator()(uint8_t x) { return x; };
  template <> uint16_t operator()(uint16_t x) { return x; };
  template <> uint32_t operator()(uint32_t x) { return x; };
  template <> uint64_t operator()(uint64_t x) { return x; };
  template <> bool operator()(bool x) { return x; };
  template <> complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  };
};

struct ArcCos {
  template <typename T> T operator()(T x) { return metal::precise::acos(x); };
};

struct ArcCosh {
  template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
};

struct ArcSin {
  template <typename T> T operator()(T x) { return metal::precise::asin(x); };
};

struct ArcSinh {
  template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
};

struct ArcTan {
  template <typename T> T operator()(T x) { return metal::precise::atan(x); };
};

struct ArcTanh {
  template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};

struct Cos {
  template <typename T> T operator()(T x) { return metal::precise::cos(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
        -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
    };
  };
};

struct Cosh {
  template <typename T> T operator()(T x) { return metal::precise::cosh(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
        metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
    };
  };
};

struct Erf {
  template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
};

struct ErfInv {
  template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
};

struct Exp {
  template <typename T> T operator()(T x) { return metal::precise::exp(x); };
  template <> complex64_t operator()(complex64_t x) {
    auto m = metal::precise::exp(x.real);
    return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
  }
};

struct Log {
  template <typename T> T operator()(T x) { return metal::precise::log(x); };
};

struct Log2 {
  template <typename T> T operator()(T x) { return metal::precise::log2(x); };
};

struct Log10 {
  template <typename T> T operator()(T x) { return metal::precise::log10(x); };
};

struct Log1p {
  template <typename T> T operator()(T x) { return log1p(x); };
};

struct LogicalNot {
  template <typename T> T operator()(T x) { return !x; };
};

struct Negative {
  template <typename T> T operator()(T x) { return -x; };
};

struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto y = 1 / (1 + metal::exp(-metal::abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};

struct Sign {
  template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
  template <> uint32_t operator()(uint32_t x) { return x != 0; };
};

struct Sin {
  template <typename T> T operator()(T x) { return metal::precise::sin(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
        metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
    };
  };
};

struct Sinh {
  template <typename T> T operator()(T x) { return metal::precise::sinh(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
        metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
    };
  };
};

struct Square {
  template <typename T> T operator()(T x) { return x * x; };
};

struct Sqrt {
  template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
};

struct Rsqrt {
  template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
};

struct Tan {
  template <typename T> T operator()(T x) { return metal::precise::tan(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    float tan_a = metal::precise::tan(x.real);
    float tanh_b = metal::precise::tanh(x.imag);
    float t1 = tan_a * tanh_b;
    float denom = 1. + t1 * t1;
    return {
        (tan_a - tanh_b * t1) / denom,
        (tanh_b + tan_a * t1) / denom
    };
  };
};

struct Tanh {
  template <typename T> T operator()(T x) { return metal::precise::tanh(x); };

  template <>
  complex64_t operator()(complex64_t x) {
    float tanh_a = metal::precise::tanh(x.real);
    float tan_b = metal::precise::tan(x.imag);
    float t1 = tanh_a * tan_b;
    float denom = 1. + t1 * t1;
    return {
        (tanh_a + tan_b * t1) / denom,
        (tan_b - tanh_a * t1) / denom
    };
  };
};

template <typename T, typename Op>
[[kernel]] void unary_op_v(
    device const T* in,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = Op()(in[index]);
}

template <typename T, typename Op>
[[kernel]] void unary_op_g(
    device const T* in,
    device T* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]) {
  auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
  out[index] = Op()(in[idx]);
}

#define instantiate_unary_v(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void unary_op_v<type, op>( \
      device const type* in, \
      device type* out, \
      uint index [[thread_position_in_grid]]);

#define instantiate_unary_g(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void unary_op_g<type, op>( \
      device const type* in, \
      device type* out, \
      device const int* in_shape, \
      device const size_t* in_strides, \
      device const int& ndim, \
      uint index [[thread_position_in_grid]]);

#define instantiate_unary_all(name, tname, type, op) \
  instantiate_unary_v("v" #name #tname, type, op) \
  instantiate_unary_g("g" #name #tname, type, op)

#define instantiate_unary_float(name, op) \
  instantiate_unary_all(name, float16, half, op) \
  instantiate_unary_all(name, float32, float, op) \
  instantiate_unary_all(name, bfloat16, bfloat16_t, op) \

#define instantiate_unary_types(name, op) \
  instantiate_unary_all(name, bool_, bool, op) \
  instantiate_unary_all(name, uint8, uint8_t, op) \
  instantiate_unary_all(name, uint16, uint16_t, op) \
  instantiate_unary_all(name, uint32, uint32_t, op) \
  instantiate_unary_all(name, uint64, uint64_t, op) \
  instantiate_unary_all(name, int8, int8_t, op) \
  instantiate_unary_all(name, int16, int16_t, op) \
  instantiate_unary_all(name, int32, int32_t, op) \
  instantiate_unary_all(name, int64, int64_t, op) \
  instantiate_unary_float(name, op)

instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)
instantiate_unary_float(log1p, Log1p)
instantiate_unary_types(neg, Negative)
instantiate_unary_float(sigmoid, Sigmoid)
instantiate_unary_float(erf, Erf)
instantiate_unary_float(erfinv, ErfInv)
instantiate_unary_types(sign, Sign)
instantiate_unary_float(sin, Sin)
instantiate_unary_float(sinh, Sinh)
instantiate_unary_types(square, Square)
instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)

instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
instantiate_unary_all(neg, complex64, complex64_t, Negative)
instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)

instantiate_unary_all(lnot, bool_, bool, LogicalNot)
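One detail worth calling out: the Sigmoid functor computes 1 / (1 + exp(-|x|)) and reflects for negative inputs, so exp is never called on a large positive argument. The same trick in plain C++ (a sketch, not part of this file):

#include <cmath>
#include <cstdio>

// Numerically stable sigmoid: exp only ever sees a non-positive argument,
// and the x < 0 branch recovers the correct lower tail via 1 - y.
float stable_sigmoid(float x) {
  float y = 1.0f / (1.0f + std::exp(-std::fabs(x)));
  return x < 0.0f ? 1.0f - y : y;
}

int main() {
  for (float x : {-100.0f, -1.0f, 0.0f, 1.0f, 100.0f})
    std::printf("sigmoid(%g) = %g\n", x, stable_sigmoid(x));
}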
369
mlx/backend/metal/reduce.cpp
Normal file
@ -0,0 +1,369 @@
#include <algorithm>
#include <cassert>
#include <sstream>

#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

//////////////////////////////////////////////////////////////////////
// Case wise reduce dispatch
//////////////////////////////////////////////////////////////////////

namespace {

// All Reduce
void all_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d) {
  // Get kernel and encode buffers
  size_t in_size = in.size();
  auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));

  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&in_size, sizeof(size_t), 2);

  // Set grid dimensions

  // We make sure each thread has enough to do by making it read in
  // at least n_reads inputs
  int n_reads = REDUCE_N_READS;

  // mod_in_size gives us the groups of n_reads needed to go over the entire
  // input
  uint mod_in_size = (in_size + n_reads - 1) / n_reads;

  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  thread_group_size =
      mod_in_size > thread_group_size ? thread_group_size : mod_in_size;

  // If the number of thread groups needed exceeds 1024, we reuse thread groups
  uint n_thread_groups =
      (mod_in_size + thread_group_size - 1) / thread_group_size;
  n_thread_groups = std::min(n_thread_groups, 1024u);
  uint nthreads = n_thread_groups * thread_group_size;

  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);

  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void row_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const std::vector<int>& axes_,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d) {
  auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));

  int n_reads = REDUCE_N_READS;
  size_t reduction_size = in.size() / out.size();

  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);

  // Each thread group is responsible for 1 output
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  thread_group_size =
      std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);

  // Align thread group size with simd_size
  uint simd_size = kernel->threadExecutionWidth();
  thread_group_size =
      (thread_group_size + simd_size - 1) / simd_size * simd_size;
  assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());

  // Launch enough thread groups for each output
  size_t n_threads = out.size() * thread_group_size;
  MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);

  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

void col_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const std::vector<int>& axes_,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d) {
  std::ostringstream kernel_name;

  bool encode_in_shape = false;
  bool encode_ndim = false;

  // If the slowest moving axis can be merged into the reductions,
  // we call the column reduce kernel
  // In this case, a linear index in the output corresponds to the
  // linear index in the input where the reduction starts
  if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
    kernel_name << "col_reduce_" << op_name << type_to_name(in);
  }
  // Otherwise, while all the reduction axes can be merged, the mapping between
  // indices in the output and input requires resolving using shapes and strides
  else {
    kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
    encode_in_shape = true;

    // We check for a viable template with the required number of dimensions;
    // we only care about encoding non-reduced shapes and strides in the input
    size_t non_reducing_dims = in.ndim() - axes_.size();
    if (non_reducing_dims >= 1 &&
        non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
      kernel_name << "_dim_" << non_reducing_dims;
    } else {
      encode_ndim = true;
    }
  }

  auto kernel = d.get_kernel(kernel_name.str());
  size_t in_size = in.size();
  size_t out_size = out.size();

  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);

  // Calculate the number of inputs to reduce and the stride b/w them
  size_t reduction_size = 1;
  size_t in_ndim = in.ndim();
  size_t reduction_stride = in_size;

  for (int i : axes_) {
    reduction_size *= in.shape(i);
    reduction_stride = std::min(reduction_stride, in.strides()[i]);
  }

  compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
  compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
  compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
  if (encode_in_shape) {
    // Obtain the non-reducing shape and strides of the input to encode
    std::vector<int> inp_shape_mod;
    std::vector<size_t> inp_strides_mod;

    for (size_t i = 0, j = 0; i < in.ndim(); i++) {
      if (j < axes_.size() && axes_[j] == i) {
        j++;
      } else {
        inp_shape_mod.push_back(in.shape(i));
        inp_strides_mod.push_back(in.strides()[i]);
      }
    }

    size_t ndim = inp_shape_mod.size();

    compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
    compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);

    if (encode_ndim) {
      compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
    }
  }

  // Select block dimensions

  // Each thread reads 16 inputs to give it more work
  uint n_inputs_per_thread = REDUCE_N_READS;
  uint n_threads_per_output =
      (reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread;

  // We spread outputs over the x dimension and inputs over the y dimension
  // Threads with the same lid.x in a given threadgroup work on the same
  // output and each thread in the y dimension accumulates for that output
  uint threadgroup_dim_x = std::min(out_size, 128ul);
  uint threadgroup_dim_y =
      kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
  threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);

  uint n_threadgroups_x =
      (out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;

  uint n_threadgroups_y =
      (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;

  // Launch enough thread groups for each output
  MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
  MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);

  // We use threadgroup (shared) memory for the reductions within a
  // threadgroup - each thread must be able to update its accumulated output
  // Note: Each threadgroup should have 32kB of data in threadgroup memory
  // and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
  // This should be fine for floats, but we might need to revisit
  // if we ever come to doubles. In that case, we should also cut
  // down the number of threads we launch in a threadgroup
  compute_encoder->setThreadgroupMemoryLength(
      threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);

  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

void general_reduce_dispatch(
    const array& in,
    array& out,
    const std::string& op_name,
    const std::vector<int>& axes_,
    MTL::ComputeCommandEncoder* compute_encoder,
    metal::Device& d) {
  bool encode_ndim = true;
  std::ostringstream kernel_name;
  kernel_name << "general_reduce_" << op_name << type_to_name(in);

  // Check for specialized kernels for input ndim
  if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
    kernel_name << "_dim_" << in.ndim();
    encode_ndim = false;
  }
  auto kernel = d.get_kernel(kernel_name.str());
  size_t in_size = in.size();
  size_t ndim = in.ndim();

  // We set the reducing strides to 0 to induce collisions for the reduction
  std::vector<size_t> out_strides(ndim);
  size_t stride = 1;
  for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
    if (j >= 0 && axes_[j] == i) {
      out_strides[i] = 0;
      --j;
    } else {
      out_strides[i] = stride;
      stride *= in.shape(i);
    }
  }

  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
  compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
  compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
  if (encode_ndim) {
    compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
  }

  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > in_size) {
    thread_group_size = in_size;
  }
  size_t nthreads = in_size;

  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
}

} // namespace

//////////////////////////////////////////////////////////////////////
// Main reduce dispatch
//////////////////////////////////////////////////////////////////////

void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  // TODO: Allow specific row and column reductions with types disabled
  // due to atomics ?
  if (size_of(in.dtype()) == 8) {
    std::ostringstream msg;
    msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
    throw std::runtime_error(msg.str());
  }

  // Make sure no identity reductions trickle down here
  assert(!axes_.empty());

  // Continue with reduction operation
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  std::string op_name;
  switch (reduce_type_) {
    case Reduce::And:
      op_name = "and";
      break;
    case Reduce::Or:
      op_name = "or";
      break;
    case Reduce::Sum:
      op_name = "sum";
      break;
    case Reduce::Prod:
      op_name = out.dtype() == bool_ ? "and" : "prod";
      break;
    case Reduce::Min:
      op_name = out.dtype() == bool_ ? "and" : "min_";
      break;
    case Reduce::Max:
      op_name = out.dtype() == bool_ ? "or" : "max_";
      break;
  }

  // Initialize output
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
    size_t nthreads = out.size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, out, 0);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }

  // Reduce
  {
    // Check for contiguous data
    if (in.size() == in.data_size() &&
        (in.flags().row_contiguous || in.flags().col_contiguous)) {
      // Go to all reduce if reducing over all axes
      if (axes_.size() == in.ndim()) {
        all_reduce_dispatch(in, out, op_name, compute_encoder, d);
        return;
      }
      // Use specialized kernels if the input is row contiguous and
      // the reducing axes can be merged into one
      else if (
          in.flags().row_contiguous && in.strides().back() == 1 &&
          (axes_.back() - axes_.front()) == axes_.size() - 1) {
        // If the fastest moving axis is being reduced, go to row reduce
        if (axes_[0] == (in.ndim() - axes_.size())) {
          row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
          return;
        }
        // Otherwise go to the generalized strided reduce
        // Note: bool isn't supported here yet due to the use of atomics
        // once that is updated, this should be the else condition of this
        // branch
        else if (in.dtype() != bool_) {
          col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
          return;
        }
      }
    }
    // Fall back to the general case
    general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
  }
}

} // namespace mlx::core
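The grid-sizing arithmetic in all_reduce_dispatch is easy to lose among the encoder calls. Here it is extracted into plain C++ with illustrative values (a sketch, not part of this file; 16 stands in for REDUCE_N_READS and 1024 for maxTotalThreadsPerThreadgroup):

#include <algorithm>
#include <cstdio>

int main() {
  size_t in_size = 1 << 20; // elements to reduce
  size_t n_reads = 16;      // reads per thread (REDUCE_N_READS)
  size_t max_group = 1024;  // maxTotalThreadsPerThreadgroup

  // Groups of n_reads needed to cover the input.
  size_t mod_in_size = (in_size + n_reads - 1) / n_reads;
  size_t group_size = std::min(mod_in_size, max_group);
  // Cap thread groups at 1024; beyond that, threads loop and reuse groups.
  size_t n_groups =
      std::min((mod_in_size + group_size - 1) / group_size, size_t(1024));
  std::printf("launch %zu threads (%zu groups x %zu)\n",
              n_groups * group_size, n_groups, group_size);
}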
15
mlx/backend/no_metal/allocator.cpp
Normal file
@ -0,0 +1,15 @@
#include "mlx/allocator.h"

namespace mlx::core::allocator {

Allocator& allocator() {
  static CommonAllocator allocator_;
  return allocator_;
}

void* Buffer::raw_ptr() {
  return ptr_;
}

} // namespace mlx::core::allocator
149
mlx/fft.h
Normal file
@ -0,0 +1,149 @@
#pragma once

#include <variant>

#include "array.h"
#include "device.h"
#include "stream.h"

namespace mlx::core::fft {

using StreamOrDevice = std::variant<std::monostate, Stream, Device>;

/** Compute the n-dimensional Fourier Transform. */
array fftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
array fftn(const array& a, StreamOrDevice s = {});

/** Compute the n-dimensional inverse Fourier Transform. */
array ifftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array ifftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array ifftn(const array& a, StreamOrDevice s = {});

/** Compute the one-dimensional Fourier Transform. */
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return fftn(a, {n}, {axis}, s);
}
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return fftn(a, {axis}, s);
}

/** Compute the one-dimensional inverse Fourier Transform. */
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return ifftn(a, {n}, {axis}, s);
}
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return ifftn(a, {axis}, s);
}

/** Compute the two-dimensional Fourier Transform. */
inline array fft2(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return fftn(a, n, axes, s);
}
inline array fft2(
    const array& a,
    const std::vector<int>& axes = {-2, -1},
    StreamOrDevice s = {}) {
  return fftn(a, axes, s);
}

/** Compute the two-dimensional inverse Fourier Transform. */
inline array ifft2(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return ifftn(a, n, axes, s);
}
inline array ifft2(
    const array& a,
    const std::vector<int>& axes = {-2, -1},
    StreamOrDevice s = {}) {
  return ifftn(a, axes, s);
}

/** Compute the n-dimensional Fourier Transform on a real input. */
array rfftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array rfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array rfftn(const array& a, StreamOrDevice s = {});

/** Compute the n-dimensional inverse of `rfftn`. */
array irfftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array irfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array irfftn(const array& a, StreamOrDevice s = {});

/** Compute the one-dimensional Fourier Transform on a real input. */
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return rfftn(a, {n}, {axis}, s);
}
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return rfftn(a, {axis}, s);
}

/** Compute the one-dimensional inverse of `rfft`. */
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return irfftn(a, {n}, {axis}, s);
}
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return irfftn(a, {axis}, s);
}

/** Compute the two-dimensional Fourier Transform on a real input. */
inline array rfft2(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return rfftn(a, n, axes, s);
}
inline array rfft2(
    const array& a,
    const std::vector<int>& axes = {-2, -1},
    StreamOrDevice s = {}) {
  return rfftn(a, axes, s);
}

/** Compute the two-dimensional inverse of `rfft2`. */
inline array irfft2(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return irfftn(a, n, axes, s);
}
inline array irfft2(
    const array& a,
    const std::vector<int>& axes = {-2, -1},
    StreamOrDevice s = {}) {
  return irfftn(a, axes, s);
}

} // namespace mlx::core::fft
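As a quick orientation, a hypothetical caller of this header could look like the sketch below; it assumes ones and eval are available from the library's other public headers, and the values are illustrative:

// Illustrative usage of the fft API declared above.
using namespace mlx::core;

auto x = ones({8});                // real-valued input
auto y = fft::fft(x);              // 1-D transform over the last axis
auto x2 = fft::ifft(y);            // inverse round trip
auto F = fft::fft2(ones({4, 4}));  // 2-D transform over axes {-2, -1}
auto r = fft::rfft(x);             // real-input transform
eval(y, x2, F, r);                 // force evaluation of the lazy graph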
240 mlx/load.cpp Normal file
@@ -0,0 +1,240 @@
#include <algorithm>
#include <cstring>
#include <fstream>
#include <limits>
#include <sstream>

#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

// Adapted from
// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp

namespace mlx::core {

namespace {

static constexpr uint8_t MAGIC[] = {
    0x93,
    0x4e,
    0x55,
    0x4d,
    0x50,
    0x59,
};

inline bool is_big_endian_() {
  union ByteOrder {
    int32_t i;
    uint8_t c[4];
  };
  ByteOrder b = {0x01234567};

  return b.c[0] == 0x01;
}

} // namespace

/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
  ////////////////////////////////////////////////////////
  // Check array

  a.eval(retain_graph);

  if (a.nbytes() == 0) {
    throw std::invalid_argument("[save] cannot serialize an empty array");
  }

  if (!a.flags().contiguous) {
    throw std::invalid_argument(
        "[save] cannot serialize a non-contiguous array");
  }

  ////////////////////////////////////////////////////////
  // Check file
  if (!out_stream->good() || !out_stream->is_open()) {
    throw std::runtime_error("[save] Failed to open " + out_stream->label());
  }

  ////////////////////////////////////////////////////////
  // Prepare header
  std::ostringstream magic_ver_len;
  magic_ver_len.write(reinterpret_cast<const char*>(MAGIC), 6);

  std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
  std::ostringstream header;
  header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
         << " 'fortran_order': " << fortran_order << ","
         << " 'shape': (";
  for (auto i : a.shape()) {
    header << i << ", ";
  }
  header << ")}";

  size_t header_len = static_cast<size_t>(header.tellp());
  bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max();

  // Pad out magic + version + header_len + header + \n to be divisible by 16
  size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16;

  header << std::string(padding, ' ') << '\n';

  if (is_v1) {
    magic_ver_len << (char)0x01 << (char)0x00;

    uint16_t v1_header_len = header.tellp();
    const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);

    if (!is_big_endian_()) {
      magic_ver_len.write(len_bytes, 2);
    } else {
      magic_ver_len.write(len_bytes + 1, 1);
      magic_ver_len.write(len_bytes, 1);
    }
  } else {
    magic_ver_len << (char)0x02 << (char)0x00;

    uint32_t v2_header_len = header.tellp();
    const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);

    if (!is_big_endian_()) {
      magic_ver_len.write(len_bytes, 4);
    } else {
      magic_ver_len.write(len_bytes + 3, 1);
      magic_ver_len.write(len_bytes + 2, 1);
      magic_ver_len.write(len_bytes + 1, 1);
      magic_ver_len.write(len_bytes, 1);
    }
  }
  ////////////////////////////////////////////////////////
  // Serialize array

  out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
  out_stream->write(header.str().c_str(), header.str().length());
  out_stream->write(a.data<char>(), a.nbytes());

  return;
}

/** Save array to file in .npy format */
void save(const std::string& file_, array a, bool retain_graph) {
  // Open and check file
  std::string file = file_;

  // Add .npy to file name if it is not there
  if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
    file += ".npy";

  // Serialize array
  save(std::make_shared<io::FileWriter>(file), a, retain_graph);
}

/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
  ////////////////////////////////////////////////////////
  // Open and check file
  if (!in_stream->good() || !in_stream->is_open()) {
    throw std::runtime_error("[load] Failed to open " + in_stream->label());
  }

  ////////////////////////////////////////////////////////
  // Read header and prepare array details

  // Read and check magic
  char read_magic_and_ver[8];
  in_stream->read(read_magic_and_ver, 8);
  if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {
    throw std::runtime_error("[load] Invalid header in " + in_stream->label());
  }

  // Read and check version
  if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
    throw std::runtime_error(
        "[load] Unsupported npy format version in " + in_stream->label());
  }

  // Read header len and header
  int header_len_size = read_magic_and_ver[6] == 1 ? 2 : 4;
  size_t header_len;

  if (header_len_size == 2) {
    uint16_t v1_header_len;
    in_stream->read(reinterpret_cast<char*>(&v1_header_len), header_len_size);
    header_len = v1_header_len;
  } else {
    uint32_t v2_header_len;
    in_stream->read(reinterpret_cast<char*>(&v2_header_len), header_len_size);
    header_len = v2_header_len;
  }

  // Read the header
  std::vector<char> buffer(header_len + 1);
  in_stream->read(&buffer[0], header_len);
  buffer[header_len] = 0;
  std::string header(&buffer[0]);

  // Read data type from header
  std::string dtype_str = header.substr(11, 3);
  bool read_is_big_endian = dtype_str[0] == '>';
  Dtype dtype = dtype_from_array_protocol(dtype_str);

  // Read contiguity order
  bool col_contiguous = header[34] == 'T';

  // Read array shape from header
  std::vector<int> shape;

  size_t st = header.find_last_of('(') + 1;
  size_t ed = header.find_last_of(')');
  std::string shape_str = header.substr(st, ed - st);

  while (!shape_str.empty()) {
    // Read current number and get position of comma
    size_t pos;
    int dim = std::stoi(shape_str, &pos);
    shape.push_back(dim);

    // Skip the comma and space and read the next number
    if (pos + 2 <= shape_str.length())
      shape_str = shape_str.substr(pos + 2);
    else {
      shape_str = shape_str.substr(pos);
      if (!shape_str.empty() && shape_str != " " && shape_str != ",") {
        throw std::runtime_error(
            "[load] Unknown error while parsing header in " +
            in_stream->label());
      }
      shape_str = "";
    }
  }

  ////////////////////////////////////////////////////////
  // Build primitive

  size_t offset = 8 + header_len_size + header.length();
  bool swap_endianness = read_is_big_endian != is_big_endian_();

  if (col_contiguous) {
    std::reverse(shape.begin(), shape.end());
  }
  auto loaded_array = array(
      shape,
      dtype,
      std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
      std::vector<array>{});
  if (col_contiguous) {
    loaded_array = transpose(loaded_array, s);
  }

  return loaded_array;
}

/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s) {
  return load(std::make_shared<io::FileReader>(file), s);
}

} // namespace mlx::core
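A sketch of a save/load round trip through the routines above; the file name and values are illustrative, and retain_graph is passed explicitly since its default lives in the header:

// Illustrative only: write an array to .npy and read it back.
auto a = ones({2, 3});
save("weights", a, /* retain_graph = */ false);  // writes weights.npy
auto b = load("weights.npy");                    // builds a lazy Load primitive
eval(b);                                         // actually reads the file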
2265 mlx/primitives.cpp Normal file
File diff suppressed because it is too large
778 mlx/transforms.cpp Normal file
@@ -0,0 +1,778 @@
#include <algorithm>
#include <future>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "mlx/backend/metal/metal.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"

namespace mlx::core {

void simplify(const std::vector<array>& outputs) {
  std::function<void(const array&)> recurse;
  std::queue<array> tape;
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
      parents_map;

  // Helpers to identify identical scalars
  std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
  auto is_scalar = [](const array& a) {
    return a.is_evaled() && a.ndim() == 0;
  };
  auto get_scalar_rep = [](const array& a) {
    uint64_t v = 0;
    int dtype;
    switch (a.dtype().size) {
      case 1:
        v = *a.data<uint8_t>();
        break;
      case 4:
        v = *a.data<uint32_t>();
        break;
      case 8:
        v = *a.data<uint64_t>();
        break;
    }
    return std::make_pair(v, a.dtype().val);
  };

  // DFS the graph to log the parents
  recurse = [&](const array& a) {
    auto id = a.id();
    if (cache.find(id) != cache.end()) {
      return;
    }
    for (int i = 0; i < a.inputs().size(); i++) {
      auto& in = a.inputs()[i];
      parents_map[in.id()].push_back({a, i});
      recurse(in);
    }
    cache.insert(id);
    tape.push(a);
    if (is_scalar(a)) {
      scalars.insert({get_scalar_rep(a), a});
    }
  };
  for (auto& a : outputs) {
    recurse(a);
  }

  // Helper that fuses two arrays in the graph by setting the parents of the
  // source to point to the destination
  auto fuse = [&](array& dst, array& src) {
    auto src_parents = parents_map.find(src.id());
    if (src_parents == parents_map.end()) {
      return;
    }

    auto& pairs = parents_map[dst.id()];
    for (auto& parent : src_parents->second) {
      parent.first.editable_inputs()[parent.second] = dst;
      pairs.push_back(parent);
    }
  };

  // Walk the graph
  cache.clear();

  // Depth-1 array equivalence check.
  auto array_equivalent = [](const array& a, const array& b) {
    if (!a.has_primitive() || !b.has_primitive()) {
      return false;
    }
    const auto& pa = a.primitive();
    const auto& pb = b.primitive();
    if (typeid(pa) != typeid(pb)) {
      return false;
    }

    if (a.inputs().size() != b.inputs().size()) {
      return false;
    }

    for (int i = 0; i < a.inputs().size(); i++) {
      if (a.inputs()[i].id() != b.inputs()[i].id()) {
        return false;
      }
    }

    return pa.is_equivalent(pb);
  };

  while (!tape.empty()) {
    auto arr = std::move(tape.front());
    tape.pop();

    if (cache.find(arr.id()) != cache.end()) {
      continue;
    }

    // Check if we can fuse scalars
    if (is_scalar(arr)) {
      auto scalar = scalars.find(get_scalar_rep(arr));
      if (scalar->second.id() != arr.id()) {
        fuse(scalar->second, arr);
        arr = scalar->second;
      }
    }

    // Check if we can fuse the parents of this array
    auto parents = parents_map.find(arr.id());
    if (parents != parents_map.end()) {
      std::vector<bool> mask(parents->second.size(), false);
      auto N = parents->second.size();
      for (int i = 0; i < N; i++) {
        if (mask[i]) {
          continue;
        }
        for (int j = i + 1; j < N; j++) {
          if (mask[j]) {
            continue;
          }
          auto& src = parents->second[j].first;
          auto& dst = parents->second[i].first;
          if (src.id() != dst.id() && array_equivalent(src, dst)) {
            cache.insert(src.id());
            fuse(dst, src);
            mask[j] = true;
          }
        }
      }
    }
  }
}

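For context, a hypothetical use of simplify; arrays and values are illustrative:

// Illustrative only: the two x + 1 subexpressions are equivalent (same
// primitive, same inputs after scalar fusion), so simplify() fuses them.
auto x = ones({4});
auto y = x + array(1);
auto z = x + array(1);
auto w = y * z;
simplify(w);  // y and z now share a single addition in the graph
eval(w);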
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
  if (!retain_graph) {
    for (auto& out : outputs) {
      if (out.has_primitive() && out.is_tracer()) {
        throw std::invalid_argument(
            "[eval] Illegal to eval an array during "
            "function transform without graph retention.");
      }
    }
  }
  std::function<void(const array&)> recurse;
  std::queue<array> tape;
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;

  recurse = [&](const array& a) {
    auto id = a.id();
    if (cache.find(id) != cache.end()) {
      return;
    }
    for (auto in : a.inputs()) {
      recurse(in);
      // If one of the inputs is being computed on a different
      // stream, we need to manage the dependency.
      if (!in.is_evaled()) {
        if (a.primitive().stream() != in.primitive().stream()) {
          deps.insert({in.id(), std::shared_future<void>{}});
        }
      }
    }
    cache.insert(id);
    if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
      if (!a.has_primitive()) {
        throw std::invalid_argument(
            "[eval] Attempting to eval an array without a primitive.");
      }
      tape.push(a);
    }
  };

  for (auto& arr : outputs) {
    if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
      recurse(arr);
      // Insert a dependency for every output to synchronize
      // with at the end.
      if (!arr.is_evaled()) {
        deps.insert({arr.id(), std::shared_future<void>{}});
      }
    }
  }

  while (!tape.empty()) {
    auto arr = std::move(tape.front());
    tape.pop();
    if (arr.is_evaled()) {
      if (!retain_graph && arr.has_primitive()) {
        arr.detach();
      }
      continue;
    }

    auto stream = arr.primitive().stream();
    std::vector<std::shared_future<void>> arr_deps;
    for (auto& in : arr.inputs()) {
      if (auto it = deps.find(in.id()); it != deps.end()) {
        arr_deps.push_back(it->second);
      }
    }
    std::shared_ptr<std::promise<void>> p{nullptr};
    if (auto it = deps.find(arr.id()); it != deps.end()) {
      p = std::make_unique<std::promise<void>>();
      it->second = p->get_future().share();
    }

    if (arr.primitive().device() == Device::gpu) {
      if (!metal::is_available()) {
        throw std::runtime_error("Metal GPU is not available.");
      }
      scheduler::enqueue(
          stream,
          metal::make_task(
              arr, std::move(arr_deps), std::move(p), retain_graph));
    } else {
      auto task = [retain_graph,
                   arr,
                   stream,
                   arr_deps = std::move(arr_deps),
                   p = std::move(p)]() mutable {
        for (auto& d : arr_deps) {
          d.wait();
        }
        scheduler::notify_new_task(stream);
        arr.primitive().eval_cpu(arr.inputs(), arr);
        if (!retain_graph) {
          arr.detach();
        }
        if (p) {
          p->set_value();
        }
        scheduler::notify_task_completion(stream);
      };
      scheduler::enqueue(stream, std::move(task));
    }
  }
  for (auto& arr : outputs) {
    if (auto it = deps.find(arr.id()); it != deps.end()) {
      it->second.wait();
    }
  }
}

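A sketch of how callers drive eval; values are illustrative:

// Illustrative only: arrays stay lazy until eval() runs the graph.
auto a = ones({2, 2});
auto b = a * a + a;
eval(b);                               // builds the tape, schedules, waits
eval({b}, /* retain_graph = */ true);  // keeps the graph for transforms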
std::pair<std::vector<array>, std::vector<array>> vjp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& cotans) {
  // Make tracers from given primals
  std::vector<array> primals_;
  for (auto& p : primals) {
    auto s = p.has_primitive() ? p.primitive().stream()
                               : default_stream(default_device());
    primals_.push_back(copy(p, s)); // Does not do a deep copy
    primals_.back().set_tracer(true);
  }

  // Pass tracer primals through the function
  // Any variables that depend on the primals are marked as tracers
  auto outputs = fun(primals_);

  // Map outputs to passed cotans while ignoring the outputs
  // that have stop_gradient called on them
  int cotan_index = 0;
  std::vector<std::pair<int, int>> output_cotan_pairs;
  for (int i = 0; i < outputs.size(); ++i) {
    auto& out = outputs[i];
    if (out.has_primitive()) {
      if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
        continue;
      }
    }
    if (cotan_index >= cotans.size()) {
      throw std::invalid_argument(
          "[vjp] Number of outputs with gradient does not match number of cotangents.");
    }
    if (out.shape() != cotans[cotan_index].shape()) {
      throw std::invalid_argument(
          "[vjp] Output shape does not match shape of cotangent.");
    }
    output_cotan_pairs.emplace_back(i, cotan_index++);
  }

  // Topologically sort the compute graph, record outputs
  // in the tape if a gradient is needed.
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_set<std::uintptr_t> calc_grad;
  for (auto& primal : primals_) {
    primal.set_tracer(false);
    calc_grad.insert(primal.id());
    cache.insert(primal.id());
  }

  std::vector<array> tape;

  std::function<void(array&)> recurse;
  recurse = [&](auto& a) {
    auto id = a.id();
    a.set_tracer(false);

    // Check if visited and add to cache if not
    if (auto inserted = cache.insert(id); !inserted.second) {
      return;
    }

    for (auto& input : a.editable_inputs()) {
      recurse(input);
    }

    // Stop grad
    if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
      return;
    }

    // Calculate gradient if any inputs require gradient
    for (auto& input : a.inputs()) {
      if (calc_grad.find(input.id()) != calc_grad.end()) {
        tape.push_back(a);
        calc_grad.insert(id);
        break;
      }
    }
  };

  for (auto& out : outputs) {
    recurse(out);
  }

  // Run the tape backwards, computing vector-jacobian
  // products for each primitive
  std::unordered_map<std::uintptr_t, array> cotan_map;
  for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
    cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
  }
  for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
    auto& a = *it;

    // Get the arguments whose gradients are needed
    std::vector<int> argnums;
    for (int i = 0; i < a.inputs().size(); ++i) {
      if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
        argnums.push_back(i);
      }
    }

    auto cotan_it = cotan_map.find(a.id());
    if (cotan_it == cotan_map.end()) {
      continue;
    }

    auto cotan = cotan_map.extract(cotan_it).mapped();
    auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums);
    auto s = a.primitive().stream();
    // Accumulate the vector-jacobian products for each input
    for (int i = 0; i < argnums.size(); ++i) {
      auto in_id = a.inputs()[argnums[i]].id();
      if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
        cotan_it->second = add(cotan_it->second, vjps[i], s);
      } else {
        cotan_map.insert({in_id, vjps[i]});
      }
    }
  }

  std::vector<array> vjps;
  for (auto& primal : primals_) {
    if (auto cotan_it = cotan_map.find(primal.id());
        cotan_it != cotan_map.end()) {
      vjps.push_back(cotan_it->second);
    } else {
      auto s = primal.has_primitive() ? primal.primitive().stream()
                                      : default_stream(default_device());
      vjps.push_back(zeros_like(primal, s));
    }
  }
  return {outputs, vjps};
}

std::pair<array, array> vjp(
    const std::function<array(const array&)>& fun,
    const array& primal,
    const array& cotan) {
  auto vec_fun = [fun](const std::vector<array>& inputs) {
    return std::vector<array>{fun(inputs[0])};
  };
  auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});
  return {outputs[0], vjps[0]};
}

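A sketch of the unary overload in use; values are illustrative:

// Illustrative only: output and VJP of f(x) = x * x at x = 1.5 with
// cotangent 1, i.e. the ordinary derivative 2x.
auto f = [](const array& x) { return x * x; };
auto [out, dx] = vjp(f, array(1.5f), array(1.0f));
// out == 2.25f, dx == 3.0f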
std::pair<std::vector<array>, std::vector<array>> jvp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& tangents) {
  if (primals.size() != tangents.size()) {
    throw std::invalid_argument(
        "[jvp] Number of inputs does not match number of tangents.");
  }
  for (int i = 0; i < primals.size(); ++i) {
    if (primals[i].shape() != tangents[i].shape()) {
      throw std::invalid_argument(
          "[jvp] Input shape does not match shape of tangent.");
    }
  }

  std::vector<array> primals_;
  for (auto& p : primals) {
    auto s = p.has_primitive() ? p.primitive().stream()
                               : default_stream(default_device());
    primals_.push_back(copy(p, s)); // Does not do a deep copy
    primals_.back().set_tracer(true);
  }
  auto outputs = fun(primals_);

  // Topologically sort the compute graph, record outputs
  // in the tape if a gradient is needed.
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_set<std::uintptr_t> calc_grad;
  for (auto& primal : primals_) {
    primal.set_tracer(false);
    calc_grad.insert(primal.id());
    cache.insert(primal.id());
  }

  std::vector<array> tape;

  std::function<void(array&)> recurse;
  recurse = [&](auto& a) {
    auto id = a.id();
    a.set_tracer(false);

    // Check if visited and add to cache if not
    if (auto inserted = cache.insert(id); !inserted.second) {
      return;
    }

    for (auto& input : a.editable_inputs()) {
      recurse(input);
    }

    // Stop grad
    if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
      return;
    }

    // Calculate gradient if any inputs require gradient
    for (auto& input : a.inputs()) {
      if (calc_grad.find(input.id()) != calc_grad.end()) {
        tape.push_back(a);
        calc_grad.insert(id);
        break;
      }
    }
  };

  for (auto& out : outputs) {
    recurse(out);
  }
  std::unordered_map<std::uintptr_t, array> tan_map;
  for (int i = 0; i < primals_.size(); ++i) {
    tan_map.insert({primals_[i].id(), tangents[i]});
  }

  for (auto& a : tape) {
    // Get the arguments used in the jvp
    std::vector<int> argnums;
    std::vector<array> tangents;
    for (int i = 0; i < a.inputs().size(); ++i) {
      if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
        argnums.push_back(i);
        tangents.push_back(it->second);
      }
    }

    auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums);
    tan_map.insert({a.id(), jvp});
  }

  std::vector<array> jvps;
  for (auto& out : outputs) {
    if (auto it = tan_map.find(out.id()); it != tan_map.end()) {
      jvps.push_back(it->second);
    } else {
      auto s = out.has_primitive() ? out.primitive().stream()
                                   : default_stream(default_device());
      jvps.push_back(zeros_like(out, s));
    }
  }
  return {outputs, jvps};
}

std::pair<array, array> jvp(
    const std::function<array(const array&)>& fun,
    const array& primal,
    const array& tangent) {
  auto vec_fun = [fun](const std::vector<array>& inputs) {
    return std::vector<array>{fun(inputs[0])};
  };
  auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});
  return {outputs[0], jvps[0]};
}

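The forward-mode counterpart, using the same f as in the vjp sketch above; illustrative:

// Illustrative only: jvp pushes a tangent forward through f(x) = x * x.
auto [outs, dout] = jvp(f, array(1.5f), array(1.0f));
// dout == 3.0f, matching the reverse-mode result for this scalar function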
ValueAndGradFn value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  if (argnums.empty()) {
    throw std::invalid_argument("[grad] Must specify at least one argument.");
  }
  return [fun, argnums](const std::vector<array>& inputs) {
    std::set<int> args;
    for (auto& arg : argnums) {
      args.insert(arg < 0 ? arg + inputs.size() : arg);
    }
    if (args.size() != argnums.size()) {
      throw std::invalid_argument(
          "[grad] Repeat argument number not allowed in grad.");
    }
    if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
      std::ostringstream msg;
      msg << "[grad] Invalid argument number for function with "
          << inputs.size() << " inputs.";
      throw std::invalid_argument(msg.str());
    }

    auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
      std::vector<array> inputs_(inputs);
      auto argit = args.begin();
      for (int i = 0; i < ginputs.size(); ++i) {
        inputs_[*argit] = ginputs[i];
        ++argit;
      }
      auto outputs = fun(inputs_);
      for (int i = 1; i < outputs.size(); i++) {
        auto& out = outputs[i];
        auto s = out.has_primitive() ? out.primitive().stream()
                                     : default_stream(default_device());
        outputs[i] = stop_gradient(out, s);
      }
      return outputs;
    };

    std::vector<array> ginputs;
    for (auto arg : args) {
      ginputs.push_back(inputs[arg]);
    }
    // Set the incoming gradient as int32 so that it will be promoted to the
    // appropriate floating point type op(int, floatXX) -> floatXX for most ops
    auto [outputs, grads] = vjp(gfun, ginputs, {array(1)});
    return std::make_pair(outputs, grads);
  };
}

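A hypothetical call through the returned closure; sum and ones are assumed to come from ops.h, and the values are illustrative:

// Illustrative only: value and gradient w.r.t. the first argument.
auto loss = [](const std::vector<array>& in) {
  return std::vector<array>{sum(in[0] * in[1])};
};
auto vg = value_and_grad(loss, std::vector<int>{0});
auto [vals, grads] = vg({ones({3}), ones({3})});
// vals[0] is the scalar loss; grads[0] has the shape of in[0]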
namespace detail {

std::pair<std::vector<array>, std::vector<array>> vmap_trace(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& inputs,
    const std::vector<int>& in_axes) {
  if (in_axes.size() != inputs.size()) {
    throw std::invalid_argument(
        "[vmap] The number of in axes must match the number of inputs.");
  }

  // Run the function on placeholder inputs
  // to get the original graph
  std::vector<array> s_inputs;
  for (int i = 0; i < inputs.size(); ++i) {
    if (in_axes[i] != -1) {
      if (inputs[i].ndim() == 0) {
        throw std::invalid_argument(
            "[vmap] Cannot vmap an input with zero dimensions.");
      }
      if (in_axes[i] > inputs[i].ndim()) {
        std::ostringstream msg;
        msg << "[vmap] Axis " << in_axes[i] << " invalid for input with "
            << inputs[i].ndim() << " dimensions.";
        throw std::invalid_argument(msg.str());
      }

      std::vector<int> shape = inputs[i].shape();
      shape.erase(shape.begin() + in_axes[i]);
      array in(shape, inputs[i].dtype(), nullptr, {});
      s_inputs.push_back(in);
      s_inputs.back().set_tracer(true);
    } else {
      s_inputs.push_back(inputs[i]);
    }
  }
  return {s_inputs, fun(s_inputs)};
}

std::vector<array> vmap_replace(
    const std::vector<array>& inputs,
    const std::vector<array>& s_inputs,
    const std::vector<array>& s_outputs,
    const std::vector<int>& in_axes,
    const std::vector<int>& out_axes) {
  if (out_axes.size() != s_outputs.size()) {
    throw std::invalid_argument(
        "[vmap] The number of out axes must match the number of outputs.");
  }

  std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
  std::unordered_set<std::uintptr_t> needs_vmap;
  for (int i = 0; i < s_inputs.size(); ++i) {
    if (in_axes[i] != -1) {
      tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}});
      needs_vmap.insert(s_inputs[i].id());
    }
  }

  // Topologically sort the graph
  std::unordered_set<std::uintptr_t> cache;
  for (int i = 0; i < s_inputs.size(); ++i) {
    auto in = s_inputs[i];
    if (in_axes[i] != -1) {
      in.set_tracer(false);
    }
    cache.insert(in.id());
  }
  std::vector<array> tape;

  std::function<void(const array&)> recurse;

  recurse = [&](const array& a) {
    // Stop at inputs to the vmap function
    auto id = a.id();
    if (cache.find(id) != cache.end()) {
      return;
    }
    for (auto& input : a.inputs()) {
      recurse(input);
    }
    cache.insert(id);
    for (auto& input : a.inputs()) {
      if (needs_vmap.find(input.id()) != needs_vmap.end()) {
        needs_vmap.insert(id);
        tape.push_back(a);
        tape.back().set_tracer(false);
        break;
      }
    }
  };

  for (auto& out : s_outputs) {
    recurse(out);
  }

  // Transform each primitive in the graph with
  // its vmap implementation
  for (auto& a : tape) {
    std::vector<array> v_inputs;
    std::vector<int> v_axes;
    for (auto& in : a.inputs()) {
      auto map_it = tmap.find(in.id());
      if (map_it != tmap.end()) {
        v_inputs.push_back(map_it->second.first);
        v_axes.push_back(map_it->second.second);
      } else {
        v_inputs.push_back(in);
        v_axes.push_back(-1);
      }
    }
    auto out_and_axis = a.primitive().vmap(v_inputs, v_axes);
    tmap.insert({a.id(), out_and_axis});
  }

  // Populate the outputs and make sure all the output axes are
  // in the right place
  std::vector<array> outputs;
  for (int i = 0; i < s_outputs.size(); ++i) {
    auto map_it = tmap.find(s_outputs[i].id());
    if (map_it != tmap.end()) {
      auto& [out, vdim] = map_it->second;
      if (vdim != out_axes[i]) {
        if (out_axes[i] >= out.ndim()) {
          std::ostringstream msg;
          msg << "[vmap] Axis " << out_axes[i] << " invalid for output with "
              << out.ndim() << " dimensions.";
          throw std::invalid_argument(msg.str());
        }
        std::vector<int> reorder(out.ndim());
        std::iota(reorder.begin(), reorder.end(), 0);
        reorder.erase(reorder.begin() + vdim);
        reorder.insert(reorder.begin() + out_axes[i], vdim);
        out = transpose(out, reorder);
      }
      outputs.push_back(out);
    } else {
      outputs.push_back(s_outputs[i]);
    }
  }
  return outputs;
}

} // namespace detail

std::function<std::vector<array>(const std::vector<array>&)> vmap(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& in_axes /* = {} */,
    const std::vector<int>& out_axes /* = {} */) {
  auto infer_axes = [](auto axes) {
    return !axes.empty() &&
        std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });
  };
  if (infer_axes(in_axes) != infer_axes(out_axes)) {
    throw std::invalid_argument(
        "[vmap] Input (or output) axes must be "
        "specified if output (or input) axes are.");
  }
  auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](
                  const std::vector<array>& inputs) mutable {
    if (in_axes.size() == 0) {
      in_axes.resize(inputs.size(), 0);
    }

    auto [trace_inputs, trace_outputs] =
        detail::vmap_trace(fun, inputs, in_axes);

    if (out_axes.size() == 0) {
      out_axes.resize(trace_outputs.size(), 0);
    }

    return detail::vmap_replace(
        inputs, trace_inputs, trace_outputs, in_axes, out_axes);
  };

  return vfun;
}

std::function<array(const array&, const array&)> vmap(
    const std::function<array(const array&, const array&)>& fun,
    int in_axis_a /* = 0 */,
    int in_axis_b /* = 0 */,
    int out_axis /* = 0 */) {
  auto vfun = vmap(
      [in_axis_a, in_axis_b, out_axis, fun](const std::vector<array>& inputs) {
        return std::vector<array>{fun(inputs[0], inputs[1])};
      },
      {in_axis_a, in_axis_b},
      {out_axis});
  return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; };
}

std::function<array(const array&)> vmap(
    const std::function<array(const array&)>& fun,
    int in_axis /* = 0 */,
    int out_axis /* = 0 */) {
  auto vfun = vmap(
      [in_axis, out_axis, fun](const std::vector<array>& inputs) {
        return std::vector<array>{fun(inputs[0])};
      },
      {in_axis},
      {out_axis});
  return [vfun](const array& a) { return vfun({a})[0]; };
}

} // namespace mlx::core
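A sketch of the unary overload; shapes are illustrative:

// Illustrative only: vectorize a function over the leading batch axis.
auto square = [](const array& x) { return x * x; };
auto vsquare = vmap(square);       // in_axis = 0, out_axis = 0
auto ys = vsquare(ones({8, 3}));   // squares each of the 8 rows
eval(ys);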
185 mlx/transforms.h Normal file
@@ -0,0 +1,185 @@
#pragma once

#include "array.h"

namespace mlx::core {

/** Fuse equivalent arrays to avoid duplicate execution. */
void simplify(const std::vector<array>& outputs);

template <typename... Arrays>
void simplify(Arrays... outputs) {
  simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
}

void eval(const std::vector<array>& outputs, bool retain_graph = false);

template <typename... Arrays>
void eval(Arrays... outputs) {
  eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
}

/**
 * Computes the output and vector-Jacobian product (VJP) of a function.
 *
 * Computes the vector-Jacobian product of the vector of cotangents with the
 * Jacobian of the function evaluated at the primals. Returns a pair of
 * vectors of output arrays and VJP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> vjp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& cotangents);

/**
 * Computes the output and vector-Jacobian product (VJP) of a unary function.
 */
std::pair<array, array> vjp(
    const std::function<array(const array&)>& fun,
    const array& primal,
    const array& cotangent);

/**
 * Computes the output and Jacobian-vector product (JVP) of a function.
 *
 * Computes the Jacobian-vector product of the Jacobian of the function
 * evaluated at the primals with the vector of tangents. Returns a pair of
 * vectors of output arrays and JVP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> jvp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& tangents);

/**
 * Computes the output and Jacobian-vector product (JVP) of a unary function.
 */
std::pair<array, array> jvp(
    const std::function<array(const array&)>& fun,
    const array& primal,
    const array& tangent);

// Return type of general value_and_grad: a function which takes an input
// vector of arrays and returns a pair of vectors of arrays, one for the
// values and one for the gradients wrt the first value.
using ValueAndGradFn =
    std::function<std::pair<std::vector<array>, std::vector<array>>(
        const std::vector<array>&)>;
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
    const std::vector<array>&)>;

/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a vector of input arrays.
 **/
ValueAndGradFn value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums);

/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a single input array.
 **/
ValueAndGradFn inline value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return value_and_grad(fun, std::vector<int>{argnum});
}

/**
 * Returns a function which computes the value and gradient of the unary
 * input function.
 **/
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
    const std::function<array(const array&)>& fun) {
  return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
}

SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [fun, argnums](auto inputs) {
    auto result = value_and_grad(
        [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
        argnums)(inputs);

    return std::make_pair(result.first[0], result.second);
  };
}

SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return value_and_grad(fun, std::vector<int>{argnum});
}

/**
 * Returns a function which computes the gradient of the input function with
 * respect to a vector of input arrays.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The vector of `argnums` specifies which arguments to compute
 * the gradient with respect to. At least one argument must be specified.
 **/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  auto fn = value_and_grad(fun, argnums);
  return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
}

/**
 * Returns a function which computes the gradient of the input function with
 * respect to a single input array.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The optional `argnum` index specifies which argument to compute
 * the gradient with respect to and defaults to 0.
 **/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return grad(fun, std::vector<int>{argnum});
}

/**
 * Returns a function which computes the gradient of the unary input function.
 **/
std::function<array(const array&)> inline grad(
    const std::function<array(const array&)>& fun) {
  auto fn = value_and_grad(fun);
  return [fn](const array& input) { return fn(input).second; };
}

/**
 * Automatically vectorize a unary function over the requested axes.
 */
std::function<array(const array&)> vmap(
    const std::function<array(const array&)>& fun,
    int in_axis = 0,
    int out_axis = 0);

/**
 * Automatically vectorize a binary function over the requested axes.
 */
std::function<array(const array&, const array&)> vmap(
    const std::function<array(const array&, const array&)>& fun,
    int in_axis_a = 0,
    int in_axis_b = 0,
    int out_axis = 0);

/**
 * Automatically vectorize a function over the requested axes.
 *
 * The input function to `vmap` takes as an argument a vector of arrays and
 * returns a vector of arrays. Optionally specify the axes to vectorize over
 * with `in_axes` and `out_axes`, otherwise a default of 0 is used.
 * Returns a vectorized function with the same signature as the input
 * function.
 */
std::function<std::vector<array>(const std::vector<array>&)> vmap(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& in_axes = {},
    const std::vector<int>& out_axes = {});

} // namespace mlx::core
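A quick sketch of the convenience wrappers declared above; values are illustrative:

// Illustrative only: grad() returns just the derivative function.
auto g = grad([](const array& x) { return x * x; });
auto dx = g(array(2.0f));  // dx == 4.0f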
185 mlx/types/bf16.h Normal file
@@ -0,0 +1,185 @@
#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

#define __MLX_BFLOAT_NAN__ 0x7FC0

namespace mlx::core {

namespace {
union float_bits_bf16 {
  float f;
  uint32_t u;
};
} // namespace

struct _MLX_BFloat16 {
  uint16_t bits_;

  // Default constructor
  _MLX_BFloat16() = default;

  // Default copy constructor
  _MLX_BFloat16(_MLX_BFloat16 const&) = default;

  // Appease std::vector<bool> for being special
  _MLX_BFloat16& operator=(std::vector<bool>::reference x) {
    bits_ = x;
    return *this;
  }

  _MLX_BFloat16& operator=(const float& x) {
    return (*this = _MLX_BFloat16(x));
  }

  // From float32
  _MLX_BFloat16(const float& x) {
    if (std::isnan(x)) {
      bits_ = __MLX_BFLOAT_NAN__;
    } else {
      // Union
      float_bits_bf16 in;

      // Take bits
      in.f = x;

      // Round to nearest even
      in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);

      // Take upper 16 bits
      bits_ = in.u >> 16;
    }
  }

  // To float32
  operator float() const {
    // Union
    float_bits_bf16 out;

    // Upper 16 bits are the data and lower 16 bits are 0s
    out.u = ((uint32_t)bits_) << 16;

    return out.f;
  }
};

#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  inline otype __operator__(atype lhs, btype rhs) {                         \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);          \
  }

#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
  inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) {            \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);     \
  }                                                                    \
  inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) {            \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);     \
  }

// Operators
#define bfloat_binop(_op_, _operator_)                                       \
  bfloat_binop_base(                                                         \
      _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(_op_, _operator_, float, float, float);                \
  bfloat_binop_helper(_op_, _operator_, double, double, double);             \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float);         \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);     \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

#undef bfloat_binop

// Comparison ops
#define bfloat_compop(__op__, __operator__)                             \
  bfloat_binop_base(                                                    \
      __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(__op__, __operator__, bool, float, float);        \
  bfloat_binop_helper(__op__, __operator__, bool, double, double);      \
  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float);      \
  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float);      \
  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

#undef bfloat_compop

// Negative
inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
  return -static_cast<float>(lhs);
}

// Inplace ops
#define bfloat_inplace_op(__op__, __operator__)                              \
  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
    lhs = lhs __op__ rhs;                                                    \
    return lhs;                                                              \
  }                                                                          \
  inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) {                \
    lhs = lhs __op__ rhs;                                                    \
    return lhs;                                                              \
  }

bfloat_inplace_op(+, operator+=);
bfloat_inplace_op(-, operator-=);
bfloat_inplace_op(*, operator*=);
bfloat_inplace_op(/, operator/=);

#undef bfloat_inplace_op

// Bitwise ops

#define bfloat_bitop(__op__, __operator__)                                  \
  inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
    _MLX_BFloat16 out;                                                      \
    out.bits_ = lhs.bits_ __op__ rhs.bits_;                                 \
    return out;                                                             \
  }                                                                         \
  inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) {      \
    _MLX_BFloat16 out;                                                      \
    out.bits_ = lhs.bits_ __op__ rhs;                                       \
    return out;                                                             \
  }                                                                         \
  inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) {      \
    _MLX_BFloat16 out;                                                      \
    out.bits_ = lhs __op__ rhs.bits_;                                       \
    return out;                                                             \
  }

bfloat_bitop(|, operator|);
bfloat_bitop(&, operator&);
bfloat_bitop(^, operator^);

#undef bfloat_bitop

#define bfloat_inplace_bitop(__op__, __operator__)                            \
  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                   \
    return lhs;                                                               \
  }                                                                           \
  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) {      \
    lhs.bits_ = lhs.bits_ __op__ rhs;                                         \
    return lhs;                                                               \
  }

bfloat_inplace_bitop(|, operator|=);
bfloat_inplace_bitop(&, operator&=);
bfloat_inplace_bitop(^, operator^=);

#undef bfloat_inplace_bitop

} // namespace mlx::core
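A small sketch of the rounding behavior implemented above; values are illustrative:

// Illustrative only: bfloat16 keeps the top 16 bits of a float32 with
// round-to-nearest-even, so values below its resolution near 1.0 collapse.
_MLX_BFloat16 a = 1.0f;
_MLX_BFloat16 b = 1.0001f;          // well under the ~2^-7 spacing at 1.0
bool same = (a.bits_ == b.bits_);   // true: both encode as 0x3F80
float back = static_cast<float>(a); // converts back to exactly 1.0f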
3 python/mlx/nn/__init__.py Normal file
@@ -0,0 +1,3 @@
from mlx.nn.layers import *
from mlx.nn import losses
from mlx.nn.utils import value_and_grad
23 python/mlx/nn/layers/__init__.py Normal file
@@ -0,0 +1,23 @@
from mlx.nn.layers.base import Module
from mlx.nn.layers.activations import (
    GELU,
    ReLU,
    SiLU,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    relu,
    silu,
)
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    TransformerEncoder,
    TransformerEncoderLayer,
)
129 python/mlx/nn/layers/activations.py Normal file
@@ -0,0 +1,129 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


def _make_activation_module(f):
    def decorator(klass):
        klass.__doc__ = f.__doc__
        klass.__call__ = lambda self, x: f(x)
        return klass

    return decorator


def relu(x):
    """Applies the Rectified Linear Unit.

    Simply ``mx.maximum(x, 0)``.
    """
    return mx.maximum(x, 0)


def silu(x):
    r"""Applies the Sigmoid Linear Unit.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
    """
    return x * mx.sigmoid(x)


def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
    approximations.
    """
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2


def gelu_approx(x):
    r"""An approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.0003` in the range :math:`[-6, 6]` using the following

    .. math::

        x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))


def gelu_fast_approx(x):
    r"""A fast approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.015` in the range :math:`[-6, 6]` using the following

    .. math::

        x = x \sigma\left(1.773 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    return x * mx.sigmoid(1.773 * x)


@_make_activation_module(relu)
class ReLU(Module):
    pass


@_make_activation_module(silu)
class SiLU(Module):
    pass


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    However, if ``approx`` is set to 'precise' or 'fast' it applies

    .. math::
        \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)

    respectively.

    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
    functional equivalents and information regarding error bounds.

    Args:
        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
    """

    def __init__(self, approx="none"):
        super().__init__()

        if approx == "none":
            self._act = gelu
        elif approx == "precise":
            self._act = gelu_approx
        elif approx == "fast":
            self._act = gelu_fast_approx
        else:
            raise ValueError(
                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
            )

    def __call__(self, x):
        return self._act(x)
6 python/mlx/nn/losses.py Normal file
@@ -0,0 +1,6 @@
import mlx.core as mx


def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
    return mx.logsumexp(logits, axis=axis) - score
136 python/mlx/utils.py Normal file
@@ -0,0 +1,136 @@
def tree_map(fn, tree, *rest):
    """Applies ``fn`` to the leaves of the python tree ``tree`` and
    returns a new collection with the results.

    If ``rest`` is provided, every item is assumed to be a superset of ``tree``
    and the corresponding leaves are provided as extra positional arguments to
    ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
    than to :func:`map`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (Callable): The function that processes the leaves of the tree
        tree (Any): The main python tree that will be iterated upon
        rest (Tuple[Any]): Extra trees to be iterated together with tree

    Returns:
        A python tree with the new values returned by ``fn``.
    """
    if isinstance(tree, list):
        return [
            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
        ]
    elif isinstance(tree, tuple):
        return tuple(
            tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
        )
    elif isinstance(tree, dict):
        return {
            k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
        }
    else:
        return fn(tree, *rest)


def tree_flatten(tree, prefix="", is_leaf=None):
    """Flattens a python tree to a list of key, value tuples.

    The keys use dot notation to define trees of arbitrary depth and
    complexity.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], ".hello"))
        # [("hello.0.0.0", 0)]

    .. note::
       Dictionaries should have keys that are valid python identifiers.

    Args:
        tree (Any): The python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
        is_leaf (Callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.

    Returns:
        List[Tuple[str, Any]]: The flat representation of the python tree.
    """
    flat_tree = []

    if is_leaf is None or not is_leaf(tree):
        if isinstance(tree, (list, tuple)):
            for i, t in enumerate(tree):
                flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
            return flat_tree
        if isinstance(tree, dict):
            for k, t in tree.items():
                flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
            return flat_tree

    return [(prefix[1:], tree)]


def tree_unflatten(tree):
    """Recreate a python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

    Args:
        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
        A python tree.
    """
    if len(tree) == 1 and tree[0][0] == "":
        return tree[0][1]

    try:
        int(tree[0][0].split(".", maxsplit=1)[0])
        is_list = True
    except ValueError:
        is_list = False

    # collect children
    children = {}
    for key, value in tree:
        current_idx, *next_idx = key.split(".", maxsplit=1)
        next_idx = "" if not next_idx else next_idx[0]
        if current_idx not in children:
            children[current_idx] = []
        children[current_idx].append((next_idx, value))

    # recursively map them to the original container
    if is_list:
        keys = sorted((int(idx), idx) for idx in children.keys())
        l = []
        for i, k in keys:
            while i > len(l):
                l.append({})
            l.append(tree_unflatten(children[k]))
        return l
    else:
        return {k: tree_unflatten(v) for k, v in children.items()}
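A round-trip sketch tying the three helpers together:

from mlx.utils import tree_flatten, tree_map, tree_unflatten

tree = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}

flat = tree_flatten(tree)
print(flat)
# [('layers.0.w', 1.0), ('layers.0.b', 0.0), ('layers.1.w', 2.0), ('layers.1.b', 0.5)]

assert tree_unflatten(flat) == tree  # unflatten inverts flatten

# With an extra tree, matching leaves arrive as extra arguments to fn
summed = tree_map(lambda a, b: a + b, tree, tree)
print(summed["layers"][0]["w"])  # 2.0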

1071
python/src/array.cpp
Normal file
File diff suppressed because it is too large

42
python/src/device.cpp
Normal file
@ -0,0 +1,42 @@
#include <sstream>

#include <pybind11/pybind11.h>

#include "mlx/device.h"
#include "mlx/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

void init_device(py::module_& m) {
  py::enum_<Device::DeviceType>(m, "DeviceType")
      .value("cpu", Device::DeviceType::cpu)
      .value("gpu", Device::DeviceType::gpu)
      .export_values()
      .def(
          "__eq__",
          [](const Device::DeviceType& d1, const Device& d2) {
            return d1 == d2;
          },
          py::prepend());

  py::class_<Device>(m, "Device")
      .def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
      .def_readonly("type", &Device::type)
      .def(
          "__repr__",
          [](const Device& d) {
            std::ostringstream os;
            os << d;
            return os.str();
          })
      .def("__eq__", [](const Device& d1, const Device& d2) {
        return d1 == d2;
      });

  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_device", &default_device);
  m.def("set_default_device", &set_default_device, "device"_a);
}
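What these bindings look like from Python (a sketch; ``cpu``/``gpu`` come from ``export_values`` and the implicit conversion registered above):

import mlx.core as mx

print(mx.default_device())        # e.g. Device(gpu, 0)
d = mx.Device(mx.DeviceType.cpu)  # explicit construction, index defaults to 0
print(d == mx.cpu)                # the custom __eq__ compares across the two types
mx.set_default_device(mx.cpu)     # DeviceType converts implicitly to Device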

12
python/src/metal.cpp
Normal file
@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>

#include "mlx/backend/metal/metal.h"

namespace py = pybind11;

using namespace mlx::core;

void init_metal(py::module_& m) {
  py::module_ metal = m.def_submodule("metal", "mlx.metal");
  metal.def("is_available", &metal::is_available);
}
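The resulting Python entry point, typically used to pick a default device (sketch):

import mlx.core as mx

mx.set_default_device(mx.gpu if mx.metal.is_available() else mx.cpu)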

32
python/src/stream.cpp
Normal file
@ -0,0 +1,32 @@
#include <sstream>

#include <pybind11/pybind11.h>

#include "mlx/stream.h"
#include "mlx/utils.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

void init_stream(py::module_& m) {
  py::class_<Stream>(m, "Stream")
      .def(py::init<int, Device>(), "index"_a, "device"_a)
      .def_readonly("device", &Stream::device)
      .def(
          "__repr__",
          [](const Stream& s) {
            std::ostringstream os;
            os << s;
            return os.str();
          })
      .def("__eq__", [](const Stream& s1, const Stream& s2) {
        return s1 == s2;
      });

  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_stream", &default_stream, "device"_a);
  m.def("set_default_stream", &set_default_stream, "stream"_a);
  m.def("new_stream", &new_stream, "device"_a);
}
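A usage sketch for the stream bindings:

import mlx.core as mx

s = mx.default_stream(mx.cpu)  # implicit DeviceType -> Device conversion
print(s)                       # __repr__ defined above via operator<<
t = mx.new_stream(mx.cpu)
print(s == t)                  # False: new_stream creates a distinct stream
mx.set_default_stream(t)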

188
python/tests/test_bf16.py
Normal file
@ -0,0 +1,188 @@
import unittest
from itertools import permutations

import math
import mlx.core as mx
import numpy as np

import mlx_tests

try:
    import torch

    has_torch = True
except ImportError as e:
    has_torch = False


class TestBF16(mlx_tests.MLXTestCase):
    def __test_ops(
        self,
        ref_op,  # Function that outputs array_like
        mlx_op,  # Function that outputs array_like
        np_args,  # Numpy arguments
        ref_transform=lambda x: x,
        mlx_transform=lambda x: mx.array(x),
        atol=1e-5,
    ):
        ref_args = map(ref_transform, np_args)
        mlx_args = map(mlx_transform, np_args)

        r_ref = ref_op(*ref_args)
        r_mlx = mlx_op(*mlx_args)

        self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))

    def __default_test(
        self,
        op,
        np_args,
        simple_transform=lambda x: x,
        atol_np=1e-3,
        atol_torch=1e-5,
        np_kwargs=dict(),
        mlx_kwargs=dict(),
        torch_kwargs=dict(),
        torch_op=None,
    ):
        with self.subTest(reference="numpy"):

            def np_transform(x):
                x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
                x_mx_fp32 = x_mx_bf16.astype(mx.float32)
                return np.asarray(x_mx_fp32)

            def mlx_fn(*args):
                out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
                return np.asarray(out_bf16.astype(mx.float32))

            def np_fn(*args):
                out_fp32 = getattr(np, op)(*args, **np_kwargs)
                return np_transform(out_fp32)

            ref_op = np_fn
            mlx_op = mlx_fn

            ref_transform = lambda x: simple_transform(np_transform(x))
            mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))

            self.__test_ops(
                ref_op,
                mlx_op,
                np_args,
                ref_transform=ref_transform,
                mlx_transform=mlx_transform,
                atol=atol_np,
            )

        if has_torch:
            with self.subTest(reference="torch"):
                torch_op = op if torch_op is None else torch_op

                def torch_fn(*args):
                    out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
                    return out_bf16.to(torch.float32).numpy()

                ref_op = torch_fn
                ref_transform = lambda x: simple_transform(
                    torch.from_numpy(x).to(torch.bfloat16)
                )
                self.__test_ops(
                    ref_op,
                    mlx_op,
                    np_args,
                    ref_transform=ref_transform,
                    mlx_transform=mlx_transform,
                    atol=atol_torch,
                )

    def test_unary_ops(self):
        x = np.random.rand(18, 28, 38)
        for op in ["abs", "exp", "log", "square", "sqrt"]:
            with self.subTest(op=op):
                np_args = (x.astype(np.float32),)
                self.__default_test(op, np_args)

    def test_binary_ops(self):
        x = np.random.rand(18, 28, 38)
        y = np.random.rand(18, 28, 38)
        for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
            with self.subTest(op=op):
                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                self.__default_test(op, np_args, simple_transform=lambda x: x)
                self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
                self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])

    def test_reduction_ops(self):
        x = np.random.rand(18, 28, 38).astype(np.float32)

        for op in ("min", "max"):
            with self.subTest(op=op):

                for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                    with self.subTest(axes=axes):
                        np_args = (x.astype(np.float32),)
                        self.__default_test(
                            op,
                            np_args,
                            np_kwargs={"axis": axes},
                            mlx_kwargs={"axis": axes},
                            torch_kwargs={"dim": axes},
                            torch_op="a" + op,
                        )

    def test_arg_reduction_ops(self):
        data = np.random.rand(10, 12, 13).astype(np.float32)
        x = mx.array(data).astype(mx.bfloat16)
        data = np.asarray(x.astype(mx.float32))

        for op in ["argmin", "argmax"]:
            for axis in range(3):
                for kd in [True, False]:
                    a = getattr(mx, op)(x, axis, kd)
                    b = getattr(np, op)(data, axis, keepdims=kd)
                    a = a.astype(mx.float32)
                    self.assertEqual(a.tolist(), b.tolist())

        for op in ["argmin", "argmax"]:
            a = getattr(mx, op)(x, keepdims=True)
            b = getattr(np, op)(data, keepdims=True)
            a = a.astype(mx.float32)
            self.assertEqual(a.tolist(), b.tolist())
            a = getattr(mx, op)(x)
            b = getattr(np, op)(data)
            a = a.astype(mx.float32)
            self.assertEqual(a.item(), b)

    def test_blas_ops(self):
        if mx.default_device() != mx.gpu:
            return

        def test_blas(shape_x, shape_y):
            np.random.seed(42)
            with self.subTest(shape_x=shape_x, shape_y=shape_y):
                x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
                y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)

                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                op = "matmul"

                self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)

        for shape_x, shape_y in [
            [(32, 32), (32, 32)],
            [(23, 57), (57, 1)],
            [(1, 3), (3, 128)],
            [(8, 128, 768), (768, 16)],
        ]:
            test_blas(shape_x, shape_y)


if __name__ == "__main__":
    unittest.main()
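The float32 round-trips in these tests exist because NumPy has no native bfloat16 dtype; converting through float32 makes results comparable. A sketch of the precision loss that motivates the loose tolerances:

import mlx.core as mx
import numpy as np

x = np.array([1.2345678], dtype=np.float32)
x_bf16 = mx.array(x).astype(mx.bfloat16)
print(np.asarray(x_bf16.astype(mx.float32)))
# prints a nearby value with only ~8 significand bits, e.g. [1.234375]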

224
tests/creations_tests.cpp
Normal file
@ -0,0 +1,224 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test arange") {
  // Check type is inferred correctly
  {
    auto x = arange(10);
    CHECK_EQ(x.dtype(), int32);

    x = arange(10.0);
    CHECK_EQ(x.dtype(), float32);

    x = arange(10, float32);
    CHECK_EQ(x.dtype(), float32);

    x = arange(10.0, int32);
    CHECK_EQ(x.dtype(), int32);

    x = arange(0, 10);
    CHECK_EQ(x.dtype(), int32);

    x = arange(0.0, 10.0, int32);
    CHECK_EQ(x.dtype(), int32);

    x = arange(0.0, 10.0);
    CHECK_EQ(x.dtype(), float32);

    x = arange(0, 10, float32);
    CHECK_EQ(x.dtype(), float32);

    x = arange(0, 10, 0.1, float32);
    CHECK_EQ(x.dtype(), float32);

    x = arange(0.0, 10.0, 0.5, int32);
    CHECK_EQ(x.dtype(), int32);

    x = arange(10.0, uint32);
    CHECK_EQ(x.dtype(), uint32);
    x = arange(0.0, 10.0, uint32);
    CHECK_EQ(x.dtype(), uint32);
    x = arange(0.0, 10.0, 0.5, uint32);
    CHECK_EQ(x.dtype(), uint32);

    // arange unsupported for bool_
    CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument);
  }

  // Check correct sizes
  {
    auto x = arange(10);
    CHECK_EQ(x.size(), 10);

    x = arange(0.0, 10.0, 0.5);
    CHECK_EQ(x.size(), 20);

    x = arange(0.0, 10.0, 0.45);
    CHECK_EQ(x.size(), 23);

    x = arange(0, 10, 10);
    CHECK_EQ(x.size(), 1);

    x = arange(0, 10, 9);
    CHECK_EQ(x.size(), 2);

    x = arange(0, 10, 100);
    CHECK_EQ(x.size(), 1);

    x = arange(0, -10, 1);
    CHECK_EQ(x.size(), 0);

    x = arange(0, -10, -1);
    CHECK_EQ(x.size(), 10);

    x = arange(0, -10, -10);
    CHECK_EQ(x.size(), 1);
  }

  // Check values
  {
    auto x = arange(0, 3);
    CHECK(array_equal(x, array({0, 1, 2})).item<bool>());

    x = arange(0, 3, 2);
    CHECK(array_equal(x, array({0, 2})).item<bool>());

    x = arange(0, 3, 3);
    CHECK(array_equal(x, array({0})).item<bool>());

    x = arange(0, -3, 1);
    CHECK(array_equal(x, array({})).item<bool>());

    x = arange(0, 3, -1);
    CHECK(array_equal(x, array({})).item<bool>());

    x = arange(0, -3, -1);
    CHECK(array_equal(x, array({0, -1, -2})).item<bool>());

    x = arange(0.0, 5.0, 0.5, int32);
    CHECK(array_equal(x, zeros({10})).item<bool>());

    x = arange(0.0, 5.0, 1.5, int32);
    CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
  }
}

TEST_CASE("test astype") {
  // Check type conversions
  {
    auto x = array(1);
    auto y = astype(x, float32);
    CHECK_EQ(y.dtype(), float32);
    CHECK_EQ(y.item<float>(), 1.0f);

    y = astype(x, int32);
    CHECK_EQ(y.dtype(), int32);
    CHECK_EQ(y.item<int>(), 1);

    x = array(-3.0f);
    y = astype(x, int32);
    CHECK_EQ(y.dtype(), int32);
    CHECK_EQ(y.item<int>(), -3);

    y = astype(x, uint32);
    CHECK_EQ(y.dtype(), uint32);

    // Use std::copy since the result is platform dependent
    uint32_t v;
    std::copy(x.data<float>(), x.data<float>() + 1, &v);
    CHECK_EQ(y.item<uint32_t>(), v);
  }
}

TEST_CASE("test full") {
  // Check full works for different types
  {
    auto x = full({}, 0);
    CHECK_EQ(x.dtype(), int32);
    CHECK_EQ(x.item<int>(), 0);

    x = full({}, 0.0);
    CHECK_EQ(x.dtype(), float32);
    CHECK_EQ(x.item<float>(), 0);

    x = full({}, false);
    CHECK_EQ(x.item<bool>(), false);

    x = full({}, 0, int32);
    CHECK_EQ(x.item<int>(), 0);

    x = full({}, 0, float32);
    CHECK_EQ(x.item<float>(), 0);

    x = full({1, 2}, 2, float32);
    CHECK(array_equal(x, array({2.0, 2.0}, {1, 2})).item<bool>());

    x = full({2, 1}, 2, float32);
    CHECK(array_equal(x, array({2.0, 2.0}, {2, 1})).item<bool>());

    x = full({2}, false);
    CHECK_EQ(x.dtype(), bool_);
    CHECK(array_equal(x, array({false, false})).item<bool>());

    x = full({2}, 1.0, bool_);
    CHECK_EQ(x.dtype(), bool_);
    CHECK(array_equal(x, array({true, true})).item<bool>());

    x = full({2}, 1.0, uint32);
    CHECK_EQ(x.dtype(), uint32);
    CHECK(array_equal(x, array({1, 1})).item<bool>());

    CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument);
  }

  // Check broadcasting works
  {
    auto x = full({2, 2}, array({3, 4}, {2, 1}));
    CHECK(array_equal(x, array({3, 3, 4, 4}, {2, 2})).item<bool>());
    x = full({2, 2}, array({3, 4}, {1, 2}));
    CHECK(array_equal(x, array({3, 4, 3, 4}, {2, 2})).item<bool>());
  }

  // Check zeros and ones
  {
    auto x = zeros({2, 2}, float32);
    CHECK_EQ(x.shape(), std::vector<int>{2, 2});
    CHECK_EQ(x.ndim(), 2);
    CHECK_EQ(x.dtype(), float32);
    auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
    CHECK(array_equal(x, y).item<bool>());

    x = ones({2, 2}, float32);
    CHECK_EQ(x.shape(), std::vector<int>{2, 2});
    CHECK_EQ(x.ndim(), 2);
    CHECK_EQ(x.dtype(), float32);
    y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
    CHECK(array_equal(x, y).item<bool>());

    x = zeros({2, 2}, int32);
    y = zeros_like(x);
    CHECK_EQ(y.dtype(), int32);
    CHECK(array_equal(x, y).item<bool>());

    x = ones({2, 2}, int32);
    y = ones_like(x);
    CHECK_EQ(y.dtype(), int32);
    CHECK(array_equal(x, y).item<bool>());
  }

  // Works for empty shape and empty array
  {
    array x = ones({}, int32);
    CHECK_EQ(x.shape(), std::vector<int>{});
    CHECK_EQ(x.item<int>(), 1);

    x = full({0}, array({}));
    CHECK_EQ(x.shape(), std::vector<int>{0});
    CHECK_EQ(x.size(), 0);

    CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);
  }
}
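The same inference and broadcasting rules, seen from Python (a sketch, assuming the equivalent ops are bound under ``mx``):

import mlx.core as mx

print(mx.arange(10).dtype)    # int32: integer arguments infer an integer type
print(mx.arange(10.0).dtype)  # float32
print(mx.full((2, 2), mx.array([[3], [4]])))  # fill value broadcasts: [[3, 3], [4, 4]]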

331
tests/fft_tests.cpp
Normal file
@ -0,0 +1,331 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test fft basics") {
  auto device = default_device();
  set_default_device(Device::cpu);
  array x(1.0);
  CHECK_THROWS(fft::fft(x));
  CHECK_THROWS(fft::ifft(x));

  x = array({1.0});
  auto y = fft::fft(x);
  CHECK_EQ(y.dtype(), complex64);
  CHECK_EQ(y.size(), x.size());
  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});

  y = fft::ifft(x);
  CHECK_EQ(y.dtype(), complex64);
  CHECK_EQ(y.size(), x.size());
  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});

  x = array({complex64_t{1.0f, 1.0f}}, complex64);
  y = fft::fft(x);
  CHECK_EQ(y.size(), x.size());
  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});

  y = fft::ifft(x);
  CHECK_EQ(y.dtype(), complex64);
  CHECK_EQ(y.size(), x.size());
  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});

  {
    x = array({0.0f, 1.0f, 2.0f, 3.0f});
    y = fft::fft(x);
    std::initializer_list<complex64_t> expected = {
        {6.0, 0.0},
        {-2.0, 2.0},
        {-2.0, 0.0},
        {-2.0, -2.0},
    };
    CHECK_EQ(y.size(), x.size());
    CHECK(array_equal(y, array(expected)).item<bool>());

    y = fft::ifft(x);
    std::initializer_list<complex64_t> expected_inv = {
        {1.5, 0.0},
        {-0.5, -0.5},
        {-0.5, 0.0},
        {-0.5, 0.5},
    };
    CHECK(array_equal(y, array(expected_inv)).item<bool>());
  }

  {
    std::initializer_list<complex64_t> vals = {
        {1.0f, 1.0f}, {2.0f, 1.0f}, {1.0f, 2.0f}, {2.0f, 2.0f}};
    x = array(vals);
    y = fft::fft(x);
    std::initializer_list<complex64_t> expected = {
        {6.0, 6.0},
        {-1.0, -1.0},
        {-2.0, 0.0},
        {1.0, -1.0},
    };
    CHECK_EQ(y.size(), x.size());
    CHECK(array_equal(y, array(expected)).item<bool>());
    CHECK(array_equal(fft::ifft(y), x).item<bool>());
  }

  // Specify axes
  {
    x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
    std::initializer_list<complex64_t> expected_0 = {
        {2.0, 0.0},
        {4.0, 0.0},
        {-2.0, 0.0},
        {-2.0, 0.0},
    };
    y = fft::fft(x, 0);
    CHECK(array_equal(y, array(expected_0, {2, 2})).item<bool>());
    CHECK(array_equal(fft::ifft(y, 0), x).item<bool>());
    std::initializer_list<complex64_t> expected_1 = {
        {1.0, 0.0},
        {-1.0, 0.0},
        {5.0, 0.0},
        {-1.0, 0.0},
    };
    y = fft::fft(x, 1);
    CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
    CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
  }
  set_default_device(device);
}

TEST_CASE("test real ffts") {
  auto device = default_device();
  set_default_device(Device::cpu);

  auto x = array({1.0});
  auto y = fft::rfft(x);
  CHECK_EQ(y.dtype(), complex64);
  CHECK_EQ(y.size(), x.size());
  CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});

  {
    x = array({0.0f, 1.0f, 2.0f, 3.0f});
    y = fft::rfft(x);
    std::initializer_list<complex64_t> expected = {
        {6.0, 0.0}, {-2.0, 2.0}, {-2.0, -0.0}};
    CHECK_EQ(y.size(), x.size() / 2 + 1);
    CHECK(array_equal(y, array(expected)).item<bool>());
  }

  x = array(complex64_t{1, 1});
  CHECK_THROWS(fft::irfft(x));

  x = array({complex64_t{0, 1}, complex64_t{1, 0}});
  y = fft::irfft(x);
  CHECK_EQ(y.size(), 2);
  CHECK_EQ(y.dtype(), float32);
  CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());

  set_default_device(device);
}

TEST_CASE("test fftn") {
  auto device = default_device();
  set_default_device(Device::cpu);

  auto x = zeros({5, 5, 5});
  CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(x, {}, {0, 0}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(x, {5, 5, 5}, {0}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(x, {0}, {}, {}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(x, {1, -1}, {}, {}), std::invalid_argument);

  // Test 2D FFT
  {
    x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
    std::initializer_list<complex64_t> expected = {
        {6.0, 0.0},
        {-2.0, 0.0},
        {-4.0, 0.0},
        {0.0, 0.0},
    };
    auto y = fft::fft2(x);
    CHECK(array_equal(y, array(expected, {2, 2})).item<bool>());
    CHECK(array_equal(fft::ifft2(y), x).item<bool>());
  }

  // Test 3D FFT
  {
    x = reshape(arange(8, float32), {2, 2, 2});
    std::initializer_list<complex64_t> expected = {
        {28.0, 0.0},
        {-4.0, 0.0},
        {-8.0, 0.0},
        {0.0, 0.0},
        {-16.0, 0.0},
        {0.0, 0.0},
        {0.0, 0.0},
        {0.0, 0.0},
    };
    auto y = fft::fftn(x);
    CHECK(array_equal(y, array(expected, {2, 2, 2})).item<bool>());
    CHECK(array_equal(fft::ifftn(y), x).item<bool>());

    x = reshape(arange(20, float32), {5, 4});
    y = fft::rfftn(x);
    CHECK_EQ(y.shape(), std::vector<int>{5, 3});
    y = fft::rfftn(x, {1, 0});
    CHECK_EQ(y.shape(), std::vector<int>{3, 4});

    x = reshape(arange(20, float32), {5, 4});
    y = fft::irfftn(x);
    CHECK_EQ(y.shape(), std::vector<int>{5, 6});
    y = fft::irfftn(x, {1, 0});
    CHECK_EQ(y.shape(), std::vector<int>{8, 4});
  }

  // Check the types of real ffts
  {
    x = zeros({5, 5}, float32);
    auto y = fft::rfft2(x);
    CHECK_EQ(y.shape(), std::vector<int>{5, 3});
    CHECK_EQ(y.dtype(), complex64);

    y = fft::rfftn(x);
    CHECK_EQ(y.shape(), std::vector<int>{5, 3});
    CHECK_EQ(y.dtype(), complex64);

    x = zeros({5, 5}, complex64);
    y = fft::irfft2(x);
    CHECK_EQ(y.shape(), std::vector<int>{5, 8});
    CHECK_EQ(y.dtype(), float32);

    y = fft::irfftn(x);
    CHECK_EQ(y.shape(), std::vector<int>{5, 8});
    CHECK_EQ(y.dtype(), float32);
  }

  set_default_device(device);
}

TEST_CASE("test fft with provided shape") {
  auto x = ones({5, 5});

  auto y = fft::fft(x, 7, 0);
  CHECK_EQ(y.shape(), std::vector<int>{7, 5});

  y = fft::fft(x, 3, 0);
  CHECK_EQ(y.shape(), std::vector<int>{3, 5});

  y = fft::fft(x, 7, 1);
  CHECK_EQ(y.shape(), std::vector<int>{5, 7});

  y = fft::fft(x, 3, 1);
  CHECK_EQ(y.shape(), std::vector<int>{5, 3});

  y = fft::rfft(x, 7, 0);
  CHECK_EQ(y.shape(), std::vector<int>{4, 5});

  y = fft::rfft(x, 3, 0);
  CHECK_EQ(y.shape(), std::vector<int>{2, 5});

  y = fft::rfft(x, 3, 1);
  CHECK_EQ(y.shape(), std::vector<int>{5, 2});
}

TEST_CASE("test fft vmap") {
  auto device = default_device();
  set_default_device(Device::cpu);

  auto fft_fn = [](array x) { return fft::fft(x); };
  auto x = reshape(arange(8), {2, 4});
  auto y = vmap(fft_fn)(x);
  CHECK(array_equal(y, fft::fft(x)).item<bool>());

  y = vmap(fft_fn, 1, 1)(x);
  CHECK(array_equal(y, fft::fft(x, 0)).item<bool>());

  auto rfft_fn = [](array x) { return fft::rfft(x); };

  y = vmap(rfft_fn)(x);
  CHECK(array_equal(y, fft::rfft(x)).item<bool>());

  y = vmap(rfft_fn, 1, 1)(x);
  CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());

  set_default_device(device);
}

TEST_CASE("test fft grads") {
  auto device = default_device();
  set_default_device(Device::cpu);

  // Regular
  auto fft_fn = [](array x) { return fft::fft(x); };
  auto cotangent = astype(arange(10), complex64);
  auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
  CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());

  auto tangent = astype(arange(10), complex64);
  auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::fft(tangent), jvp_out).item<bool>());

  // Inverse
  auto ifft_fn = [](array x) { return fft::ifft(x); };
  vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
  CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());

  jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());

  // Real
  auto rfft_fn = [](array x) { return fft::rfft(x); };
  cotangent = astype(arange(6), complex64);
  vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
  auto expected = astype(fft::fft(cotangent, 10, 0), float32);
  CHECK(array_equal(expected, vjp_out).item<bool>());

  tangent = astype(arange(10), float32);
  jvp_out = jvp(rfft_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::rfft(tangent), jvp_out).item<bool>());

  // Inverse real
  auto irfft_fn = [](array x) { return fft::irfft(x); };
  cotangent = astype(arange(10), float32);
  vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
  expected = fft::fft(cotangent, 10, 0);
  auto o_splits = split(vjp_out, {1, 5});
  auto e_splits = split(expected, {1, 5, 6});
  CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
  CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
  CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());

  tangent = astype(arange(10), complex64);
  jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::irfft(tangent), jvp_out).item<bool>());

  // Check ND vjps run properly
  vjp_out = vjp([](array x) { return fft::fftn(x); },
                astype(zeros({5, 5}), complex64),
                astype(zeros({5, 5}), complex64))
                .second;
  CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});

  vjp_out = vjp([](array x) { return fft::ifftn(x); },
                astype(zeros({5, 5}), complex64),
                astype(zeros({5, 5}), complex64))
                .second;
  CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});

  vjp_out = vjp([](array x) { return fft::rfftn(x); },
                zeros({5, 9}),
                astype(zeros({5, 5}), complex64))
                .second;
  CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 9});

  vjp_out = vjp([](array x) { return fft::irfftn(x); },
                astype(zeros({5, 5}), complex64),
                zeros({5, 8}))
                .second;
  CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});

  set_default_device(device);
}
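The rfft shape bookkeeping these checks rely on, from Python (a sketch, assuming the same functions are bound under ``mx.fft``): a real transform of length n keeps n // 2 + 1 frequencies, and irfft defaults to inverting back to 2 * (m - 1) samples.

import mlx.core as mx

x = mx.ones((5, 5))
y = mx.fft.rfft(x)            # last axis: 5 -> 5 // 2 + 1 = 3
print(y.shape)                # (5, 3)
print(mx.fft.irfft(y).shape)  # (5, 4): the default n is 2 * (3 - 1)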

30
tests/graph_optimize_tests.cpp
Normal file
@ -0,0 +1,30 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test simplify scalars") {
  auto a = array({-1.0f, 2.0f});
  auto b = maximum(a, array(0.0f));
  auto c = maximum(-a, array(0.0f));
  auto d = b + c;
  simplify({d});
  CHECK(b.inputs()[1].id() == c.inputs()[1].id());
}

TEST_CASE("test simplify") {
  auto a = array({1.0f, 2.0f});
  auto b = exp(a) + exp(a);
  simplify(b);
  eval(b);
  CHECK(b.inputs()[0].id() == b.inputs()[1].id());
}

TEST_CASE("test no simplify") {
  auto a = array({1.0f, 2.0f});
  auto b = cos(a) + sin(a);
  simplify(b);
  eval(b);
  CHECK(b.inputs()[0].id() != b.inputs()[1].id());
}
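The intent of these checks, in words: ``simplify`` deduplicates identical subgraphs (the two scalar ``array(0.0f)`` literals, the two ``exp(a)`` nodes) so each is computed once, while distinct ops such as ``cos`` and ``sin`` are left alone. A sketch of the second test, assuming the transform is also exposed in Python as ``mx.simplify``:

import mlx.core as mx

a = mx.array([1.0, 2.0])
b = mx.exp(a) + mx.exp(a)
mx.simplify(b)  # hypothetical Python spelling of the C++ call above
mx.eval(b)      # after simplification, both addends share one exp node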