angelos's commit files

Angelos Katharopoulos 2023-11-29 10:42:59 -08:00
parent 8ca7f9e8e9
commit d1f86272a2
56 changed files with 12350 additions and 0 deletions

.gitignore

@ -0,0 +1,75 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Metal libraries
*.metallib
# Distribution / packaging
python/mlx/share
python/mlx/include
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# vim
*.swp
# Ignore build dir
build/
# Prerequisites
*.d
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
*.exe
*.out
*.app
# VSCode
.vscode/
.DS_Store

.pre-commit-config.yaml

@ -0,0 +1,9 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v14.0.6
hooks:
- id: clang-format
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black

CMakeLists.txt

@ -0,0 +1,197 @@
cmake_minimum_required(VERSION 3.24)
project(mlx LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.0.1)
endif()
# ----------------------------- Lib -----------------------------
include(FetchContent)
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
cmake_policy(SET CMP0135 NEW)
add_library(mlx)
if (MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
find_library(QUARTZ_LIB QuartzCore)
endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)
execute_process(COMMAND zsh "-c" "/usr/bin/sw_vers | cut -f2- -d: | sed -n 2p | grep -Eo '[0-9]+.[0-9]+'"
OUTPUT_VARIABLE MACOS_VERSION)
message(STATUS "Detected macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
endif()
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
)
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>
)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()
if (MLX_BUILD_TESTS)
include(CTest)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()
if (MLX_BUILD_EXAMPLES)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()
if (MLX_BUILD_BENCHMARKS)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)
# Install library
install(
TARGETS mlx
EXPORT MLXTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
# Install headers
install(
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT headers
FILES_MATCHING PATTERN "*.h"
)
# Install metal dependencies
if (MLX_BUILD_METAL)
# Install metal cpp
install(
DIRECTORY ${metal_cpp_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
COMPONENT metal_cpp_source
)
endif()
# Install cmake config
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
EXPORT MLXTargets
FILE MLXTargets.cmake
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
include(CMakePackageConfigHelpers)
write_basic_package_version_file(
${MLX_CMAKE_BUILD_VERSION_CONFIG}
COMPATIBILITY SameMajorVersion
VERSION ${MLX_VERSION}
)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
${MLX_CMAKE_BUILD_CONFIG}
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
NO_CHECK_REQUIRED_COMPONENTS_MACRO
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)
install(
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(
DIRECTORY ${CMAKE_MODULE_PATH}/
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)


@ -0,0 +1,38 @@
#pragma once
#include <chrono>
#include <iomanip>
#include <iostream>
#include "mlx/mlx.h"
#define milliseconds(x) \
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
#define time_now() std::chrono::high_resolution_clock::now()
#define TIME(FUNC, ...) \
std::cout << "Timing " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));
}
int num_iters = 100;
auto start = time_now();
for (int i = 0; i < num_iters; i++) {
eval(fn(std::forward<Args>(args)...));
}
auto end = time_now();
return milliseconds(end - start) / static_cast<double>(num_iters);
}


@ -0,0 +1,60 @@
import argparse
import mlx.core as mx
from time_utils import time_fn
B = 8
T = 1024
D = 512
def time_batch_matmul():
mx.random.seed(3)
a = mx.random.uniform(shape=(B, T, D))
b = mx.random.uniform(shape=(D, D))
c = mx.random.uniform(shape=(B, T, D))
mx.eval(a, b, c)
time_fn(mx.matmul, a, b)
def batch_vjp_first():
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
time_fn(batch_vjp_first)
def batch_vjp_second():
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
time_fn(batch_vjp_second)
def time_unbatch_matmul():
mx.random.seed(3)
a = mx.random.uniform(shape=(B * T, D))
b = mx.random.uniform(shape=(D, D))
c = mx.random.uniform(shape=(B * T, D))
mx.eval(a, b, c)
time_fn(mx.matmul, a, b)
def unbatch_vjp_first():
return mx.matmul(c, mx.transpose(b))
time_fn(unbatch_vjp_first)
def unbatch_vjp_second():
return mx.matmul(mx.transpose(a), c)
time_fn(unbatch_vjp_second)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_batch_matmul()
time_unbatch_matmul()
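
For reference, the unbatched helpers above time the closed-form matmul cotangents directly; with C = A B and an upstream cotangent \bar{C}, the identities they implement are

    \bar{A} = \bar{C}\, B^{\top}, \qquad \bar{B} = A^{\top}\, \bar{C}

which is what the batched variants recover through mx.vjp.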


@ -0,0 +1,20 @@
import time
import mlx.core as mx
def time_fn(fn, *args, **kwargs):
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
mx.eval(fn(*args, **kwargs))
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
mx.eval(fn(*args, **kwargs))
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")

docs/.clang-format

@ -0,0 +1,2 @@
DisableFormat: true
SortIncludes: Never

docs/Makefile

@ -0,0 +1,18 @@
# Minimal makefile for Sphinx documentation
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = src
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docs/src/examples/mlp.rst

@ -0,0 +1,131 @@
.. _mlp:
Multi-Layer Perceptron
----------------------
In this example we'll learn to use ``mlx.nn`` by implementing a simple
multi-layer perceptron to classify MNIST.
As a first step, import the MLX packages we need:
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
The model is defined as the ``MLP`` class which inherits from
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:
1. Define an ``__init__`` where the parameters and/or submodules are set up. See
the :ref:`Module class docs<module_class>` for more information on how
:class:`mlx.nn.Module` registers parameters.
2. Define a ``__call__`` where the computation is implemented.
.. code-block:: python
class MLP(nn.Module):
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
return self.layers[-1](x)
We define the loss function which takes the mean of the per-example cross
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
commonly used loss functions.
.. code-block:: python
def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))
We also need a function to compute the accuracy of the model on the validation
set:
.. code-block:: python
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
Next, set up the problem parameters and load the data:
.. code-block:: python
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
# Load the data
import mnist
train_images, train_labels, test_images, test_labels = map(
mx.array, mnist.mnist()
)
Since we're using SGD, we need an iterator which shuffles and constructs
minibatches of examples in the training set:
.. code-block:: python
def batch_iterate(batch_size, X, y):
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
Finally, we put it all together by instantiating the model, the
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:
.. code-block:: python
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())
# Get a function which gives the loss and gradient of the
# loss with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Instantiate the optimizer
optimizer = optim.SGD(learning_rate=learning_rate)
for e in range(num_epochs):
for X, y in batch_iterate(batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(model, X, y)
# Update the optimizer state and model parameters
# in a single call
optimizer.update(model, grads)
# Force a graph evaluation
mx.eval(model.parameters(), optimizer.state)
accuracy = eval_fn(model, test_images, test_labels)
print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
.. note::
The :func:`mlx.nn.value_and_grad` function is a convenience function to get
the gradient of a loss with respect to the trainable parameters of a model.
This should not be confused with :func:`mlx.core.value_and_grad`.
The model should train to a decent accuracy (about 95%) after just a few passes
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
is available in the MLX GitHub repo.
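For a concrete sense of the distinction made in the note above, a minimal sketch
(illustrative only): :func:`mlx.core.value_and_grad` differentiates with respect to an
explicit array argument, while :func:`mlx.nn.value_and_grad` differentiates with respect
to the model's trainable parameters.
.. code-block:: python
   # Gradient with respect to an explicit argument
   square_sum = mx.value_and_grad(lambda x: (x * x).sum())
   value, dx = square_sum(mx.ones((3,)))
   # Gradient with respect to the model's trainable parameters
   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
   loss, grads = loss_and_grad_fn(model, test_images, test_labels)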

docs/src/python/array.rst

@ -0,0 +1,45 @@
.. _array:
Array
=====
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
array
array.astype
array.item
array.tolist
array.dtype
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.cos
array.dtype
array.exp
array.log
array.log1p
array.logsumexp
array.max
array.mean
array.min
array.prod
array.reciprocal
array.reshape
array.rsqrt
array.sin
array.split
array.sqrt
array.square
array.sum
array.transpose
array.T
array.var

docs/src/python/ops.rst

@ -0,0 +1,94 @@
.. _ops:
Operations
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
abs
add
all
allclose
any
arange
arccos
arccosh
arcsin
arcsinh
arctan
arctanh
argmax
argmin
argpartition
argsort
array_equal
broadcast_to
concatenate
convolve
conv1d
conv2d
cos
cosh
divide
equal
erf
erfinv
exp
expand_dims
full
greater
greater_equal
less
less_equal
load
log
log2
log10
log1p
logaddexp
logical_not
logsumexp
matmul
max
maximum
mean
min
minimum
multiply
negative
ones
ones_like
partition
pad
prod
reciprocal
reshape
rsqrt
save
savez
savez_compressed
sigmoid
sign
sin
sinh
softmax
sort
split
sqrt
square
squeeze
stop_gradient
subtract
sum
take
take_along_axis
tan
tanh
transpose
var
where
zeros
zeros_like

examples/cpp/timer.h

@ -0,0 +1,18 @@
#pragma once
#include <chrono>
namespace timer {
using namespace std::chrono;
template <typename R, typename P>
inline double seconds(duration<R, P> x) {
return duration_cast<nanoseconds>(x).count() / 1e9;
}
inline auto time() {
return high_resolution_clock::now();
}
} // namespace timer


@ -0,0 +1,359 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include "axpby/axpby.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#endif
#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#endif
namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////
/**
* Scale and sum two vectors elementwise
* z = alpha * x + beta * y
*
* Follows numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
///////////////////////////////////////////////////////////////////////////////
// Primitive Common Backend Implementation
///////////////////////////////////////////////////////////////////////////////
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the elementwise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef ACCELERATE_NEW_LAPACK
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
#endif
///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation
///////////////////////////////////////////////////////////////////////////////
#ifdef _METAL_
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Prepare to specialize based on contiguity
bool contiguous_kernel =
(x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous);
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_";
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim if needed
if (!contiguous_kernel) {
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available
/** Fail evaluation on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
#endif
///////////////////////////////////////////////////////////////////////////////
// Primitive Transforms
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() == 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return multiply(scale_arr, tangents[0], stream());
}
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
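// For reference, both transforms follow from differentiating f(x, y) = alpha * x + beta * y;
// a sketch of the identities, in the notation of the comments above:
//
//     \frac{\partial f}{\partial x} = \alpha, \qquad \frac{\partial f}{\partial y} = \beta
//
//     \mathrm{jvp} = \alpha\, t_x + \beta\, t_y, \qquad \mathrm{vjp} = (\alpha\, \bar{z},\ \beta\, \bar{z})
//
// where t_x, t_y are the tangents of the selected arguments and \bar{z} is the cotangent of the
// output; selecting a single argnum keeps only the corresponding term.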
/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
/** Equivalence check **/
bool Axpby::is_equivalent(const Primitive& other) const {
const Axpby& r_other = static_cast<const Axpby&>(other);
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core


@ -0,0 +1,39 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "axpby/axpby.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors elementwise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)pbdoc");
}
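
A minimal usage sketch, assuming the built extension is importable as mlx_sample_extensions (the module name declared above); alpha and beta follow the two positional-only arrays, and stream is keyword-only:

    import mlx.core as mx
    from mlx_sample_extensions import axpby

    x = mx.ones((3, 4))
    y = mx.ones((3, 4))
    # z = 2.0 * x + 0.5 * y, with numpy-style broadcasting and upcasting
    z = axpby(x, y, 2.0, 0.5)
    print(z)  # every element equals 2.5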

mlx/CMakeLists.txt

@ -0,0 +1,36 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
)
endif()
if (MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()

mlx/array.cpp

@ -0,0 +1,143 @@
#include <functional>
#include "mlx/array.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
namespace mlx::core {
namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
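// cum_prod returns the element count together with row-major (C-order) strides, where the
// stride of axis i is the product of all dimensions to its right. A small Python sketch of
// the same recurrence, for illustration only:
//
//     def row_major_strides(shape):
//         strides = [1] * len(shape)
//         size = 1
//         # walk the shape right to left, accumulating the running product
//         for i in reversed(range(len(shape))):
//             strides[i] = size
//             size *= shape[i]
//         return size, strides
//
//     # row_major_strides([2, 3, 4]) == (24, [12, 4, 1])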
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
auto cval = static_cast<complex64_t>(val);
init(&cval);
}
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
dtype,
std::move(primitive),
inputs)) {}
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
float32)) {
init(data.begin());
}
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
set_data(data, deleter);
}
void array::detach() {
array_desc_->inputs.clear();
array_desc_->primitive = nullptr;
}
void array::eval(bool retain_graph /* = false */) {
mlx::core::eval({*this}, retain_graph);
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
auto max_dim = std::max_element(shape().begin(), shape().end());
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
}
void array::set_data(
allocator::Buffer buffer,
size_t data_size,
std::vector<size_t> strides,
Flags flags,
deleter_t d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
}
void array::copy_shared_buffer(
const array& other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}
void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
}
array::ArrayDesc::ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::unique_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
// Needed because the Primitive type used in array.h is incomplete and the
// compiler needs to see the call to the destructor after the type is complete.
array::ArrayDesc::~ArrayDesc() = default;
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();
shape.erase(shape.begin());
start[0] = idx;
end[0] = idx + 1;
return reshape(slice(arr, start, end), shape);
};
} // namespace mlx::core


@ -0,0 +1,323 @@
#include <cassert>
#include <limits>
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x) {
x *= 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
ipart = simd::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
return (*(simd_float16*)&epart) * x;
}
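// In the notation of the comment above, the kernel evaluates exp through the decomposition
// (a sketch restating the code, not additional functionality):
//
//     e^{x} = 2^{y}, \quad y = x \log_2 e, \qquad
//     2^{y} = 2^{\mathrm{ipart}} \cdot 2^{\mathrm{fpart}}, \quad
//     \mathrm{ipart} = \lfloor y + 0.5 \rfloor, \ \mathrm{fpart} = y - \mathrm{ipart}
//
// 2^{fpart} is approximated by the degree-6 polynomial evaluated above in Horner form, and
// 2^{ipart} is built by writing ipart directly into the float32 exponent field, which is what
// the (+127) bias and << 23 shift do.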
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
int16x8_t epart = vcvtq_s16_f16(ipart);
epart = vaddq_s16(epart, vdupq_n_s16(15));
epart = vshlq_n_s16(epart, 10);
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}
/**
* Implementation of folding maximum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_max(float16x8_t x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
/**
* Implementation of folding sum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_add(float16x8_t x) {
float16x4_t y;
float16x4_t zero = vdup_n_f16(0);
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
y = vpadd_f16(y, zero);
y = vpadd_f16(y, zero);
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
};
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
};
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
current_in_ptr += N;
s -= N;
}
T maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, *current_in_ptr);
current_in_ptr++;
}
// Compute the normalizer and the exponentials
VT vnormalizer = ops.init(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
T normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
s = M;
while (s >= N) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
*current_out_ptr *= normalizer;
current_out_ptr++;
}
}
}
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
break;
case float16:
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
break;
case bfloat16:
eval(inputs, out);
break;
case complex64:
eval(inputs, out);
break;
}
}
} // namespace mlx::core
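
Both SIMD paths above implement the standard numerically stable softmax over the last axis in three passes (running maximum, exponentiate and accumulate the normalizer, then scale). A reference sketch of the same computation in Python, for illustration only:

    import numpy as np

    def softmax_last_axis(x):
        # pass 1: per-row maximum, subtracted for numerical stability
        m = x.max(axis=-1, keepdims=True)
        # pass 2: exponentials and their sum (the normalizer)
        e = np.exp(x - m)
        # pass 3: multiply by the reciprocal of the normalizer
        return e / e.sum(axis=-1, keepdims=True)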


@ -0,0 +1,216 @@
#include <cassert>
#include <cmath>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, U, Op> opsv(op);
DefaultVectorScalar<T, U, Op> opvs(op);
DefaultVectorVector<T, U, Op> opvv(op);
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
}
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
comparison_op<bool, bool>(a, b, out, op);
break;
case uint8:
comparison_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
comparison_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
comparison_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
comparison_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
comparison_op<int8_t, bool>(a, b, out, op);
break;
case int16:
comparison_op<int16_t, bool>(a, b, out, op);
break;
case int32:
comparison_op<int32_t, bool>(a, b, out, op);
break;
case int64:
comparison_op<int64_t, bool>(a, b, out, op);
break;
case float16:
comparison_op<float16_t, bool>(a, b, out, op);
break;
case float32:
comparison_op<float, bool>(a, b, out, op);
break;
case bfloat16:
comparison_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
comparison_op<complex64_t, bool>(a, b, out, op);
break;
}
}
} // namespace
void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (equal_nan_) {
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
return x == y || (std::isnan(x) && std::isnan(y));
});
} else {
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
}
}
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
}
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
}
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
}
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
}
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto op = [](auto x, auto y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = (x > y) ? x : y;
auto minval = (x > y) ? y : x;
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(std::exp(minval - maxval)));
};
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, op);
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, op);
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, op);
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
" non floating point type.");
}
}
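// The lambda above is the usual overflow-safe rewrite of logaddexp; for reference:
//
//     \log\left(e^{x} + e^{y}\right) = \max(x, y) + \log\left(1 + e^{\min(x, y) - \max(x, y)}\right)
//
// with the result pinned to max(x, y) when either the minimum is -inf or the maximum is +inf,
// exactly as the infinity checks in the code do.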
void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(
inputs[0], inputs[1], out, [](auto x, auto y) { return x != y; });
}
struct PowerFn {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
if (exp < 0) {
throw std::invalid_argument(
"Integers cannot be raise to negative powers");
}
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
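// The integral branch of PowerFn is exponentiation by squaring; a short Python sketch of the
// same loop, for illustration only:
//
//     def int_pow(base, exp):
//         if exp < 0:
//             raise ValueError("integers cannot be raised to negative powers")
//         result = 1
//         while exp:
//             if exp & 1:        # fold in the current bit's power of base
//                 result *= base
//             exp >>= 1
//             base *= base       # base, base**2, base**4, ...
//         return result
//
//     # int_pow(3, 5) == 243, using O(log exp) multiplications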
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, PowerFn{});
}
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
} // namespace mlx::core

mlx/backend/common/binary.h

@ -0,0 +1,554 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
enum BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
VectorVector,
General,
};
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = VectorVector;
} else {
bopt = General;
}
return bopt;
}
void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
switch (bopt) {
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case ScalarVector:
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
break;
case VectorScalar:
case VectorVector:
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
break;
case General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
struct UseDefaultBinaryOp {
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
// Should we throw? This should normally never be called.
assert(false);
}
};
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
Op op;
DefaultVectorScalar(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
while (size-- > 0) {
*dst = op(*a, scalar);
dst++;
a++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
Op op;
DefaultScalarVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
while (size-- > 0) {
*dst = op(scalar, *b);
dst++;
b++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
Op op;
DefaultVectorVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
while (size-- > 0) {
*dst = op(*a, *b);
dst++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims1(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; i++) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
dst += stride;
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims2(
const array& a,
const array& b,
array& out,
Op op,
int stride) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
dst += stride;
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op);
return;
case 3:
binary_op_dims3<T, U, Op>(a, b, out, op);
return;
case 4:
binary_op_dims4<T, U, Op>(a, b, out, op);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
}
}
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
int stride) {
// Number of dimensions to loop over for vectorized ops
switch (dim) {
case 1:
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
return;
case 2:
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
return;
}
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i += stride) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
dst += stride;
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
// Get the left-most dim such that the array is row contiguous after
auto& strides = out.strides();
auto leftmost_rc_dim = [&strides](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a);
auto b_rc_dim = leftmost_rc_dim(b);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const array& arr) {
int d = arr.ndim() - 1;
for (; d >= 0 && arr.strides()[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a);
auto b_s_dim = leftmost_s_dim(b);
auto ndim = out.ndim();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:
binary_op_dispatch_dims<T, U>(a, b, out, op);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
}
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
}
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, ops...);
break;
case uint8:
binary_op<uint8_t>(a, b, out, ops...);
break;
case uint16:
binary_op<uint16_t>(a, b, out, ops...);
break;
case uint32:
binary_op<uint32_t>(a, b, out, ops...);
break;
case uint64:
binary_op<uint64_t>(a, b, out, ops...);
break;
case int8:
binary_op<int8_t>(a, b, out, ops...);
break;
case int16:
binary_op<int16_t>(a, b, out, ops...);
break;
case int32:
binary_op<int32_t>(a, b, out, ops...);
break;
case int64:
binary_op<int64_t>(a, b, out, ops...);
break;
case float16:
binary_op<float16_t>(a, b, out, ops...);
break;
case float32:
binary_op<float>(a, b, out, ops...);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, ops...);
break;
case complex64:
binary_op<complex64_t>(a, b, out, ops...);
break;
}
}
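// A minimal sketch of a call site (illustrative, not from this commit): an
// elementwise op with no specialized vector kernels can simply call
//   binary(a, b, out, [](auto x, auto y) { return x + y; });
// and the DefaultScalarVector / DefaultVectorScalar / DefaultVectorVector
// wrappers above fill in the vectorized paths.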
} // namespace
} // namespace mlx::core

@ -0,0 +1,85 @@
#include <numeric>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
std::vector<std::ptrdiff_t> strides_in(
in.strides().begin(), in.strides().end());
for (auto& s : strides_in) {
s *= in.itemsize();
}
std::vector<std::ptrdiff_t> strides_out(
out.strides().begin(), out.strides().end());
for (auto& s : strides_out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<size_t> shape;
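  // pocketfft expects the shape of the real-valued array for r2c/c2r, so use
  // the output shape when the output is real (c2r) and the input shape
  // otherwise (r2c, and c2c where the two shapes coincide).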
if (out.dtype() == float32) {
shape.insert(shape.end(), out.shape().begin(), out.shape().end());
} else {
shape.insert(shape.end(), in.shape().begin(), in.shape().end());
}
float scale = 1.0f;
if (inverse_) {
size_t nelem = std::accumulate(
axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
return x * shape[y];
});
scale /= nelem;
}
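  // e.g. an inverse transform over two axes of sizes 8 and 16 is scaled by
  // 1 / (8 * 16) = 1 / 128.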
if (in.dtype() == complex64 && out.dtype() == complex64) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::c2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == float32 && out.dtype() == complex64) {
auto in_ptr = in.data<float>();
auto out_ptr =
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
pocketfft::r2c(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else if (in.dtype() == complex64 && out.dtype() == float32) {
auto in_ptr =
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
auto out_ptr = out.data<float>();
pocketfft::c2r(
shape,
strides_in,
strides_out,
axes_,
!inverse_,
in_ptr,
out_ptr,
scale);
} else {
throw std::runtime_error(
"[FFT] Received unexpected input and output type combination.");
}
}
} // namespace mlx::core

@ -0,0 +1,98 @@
#include <cassert>
#include <cmath>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void softmax(const array& in, array& out) {
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int N = in.shape().back();
int M = in.data_size() / N;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
// Find the maximum
current_in_ptr = in_ptr;
T maximum = *current_in_ptr;
for (int j = 0; j < N; j++, current_in_ptr++) {
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
}
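    // Subtracting the row maximum keeps exp() in (0, 1]; the result is
    // unchanged since softmax(x) == softmax(x - c) for any constant c.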
// Compute the normalizer and the exponentials
T normalizer = 0;
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
T expv = std::exp(*current_in_ptr - maximum);
normalizer += expv;
*current_out_ptr = expv;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) {
*current_out_ptr *= normalizer;
}
}
}
} // namespace
void Softmax::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
if (x.strides().back() == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float>(in, out);
break;
case float16:
softmax<float16_t>(in, out);
break;
case bfloat16:
softmax<bfloat16_t>(in, out);
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
break;
}
}
} // namespace mlx::core

@ -0,0 +1,200 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
#include <cstdlib>
namespace mlx::core {
namespace allocator {
Allocator& allocator() {
return metal::allocator();
}
void* Buffer::raw_ptr() {
return static_cast<MTL::Buffer*>(ptr_)->contents();
}
} // namespace allocator
namespace metal {
namespace {
BufferCache::BufferCache(MTL::Device* device)
: device_(device),
head_(nullptr),
tail_(nullptr),
pool_size_(0),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
BufferCache::~BufferCache() {
clear();
}
void BufferCache::clear() {
std::lock_guard<std::mutex> lk(cache_mutex_);
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf)
holder->buf->release();
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr;
auto it = buffer_pool_.lower_bound(size);
  // Make sure the buffer we reuse is at least 50% utilized, i.e. no more than
  // twice the requested size
while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) {
// Collect from the cache
pbuf = it->second->buf;
// Remove from cache
remove_from_list(it->second);
delete it->second;
it = buffer_pool_.erase(it);
}
if (pbuf) {
pool_size_ -= pbuf->length();
}
return pbuf;
}
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Add to cache
if (buf) {
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
pool_size_ += buf->length();
buffer_pool_.insert({buf->length(), bh});
}
}
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
if (min_bytes_to_free >= 0.9 * pool_size_) {
size_t old_pool_size = pool_size_;
clear();
return old_pool_size;
} else {
std::lock_guard<std::mutex> lk(cache_mutex_);
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
tail_->buf->release();
tail_->buf = nullptr;
}
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return total_bytes_freed;
}
}
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
if (!to_add)
return;
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
if (!to_remove)
return;
// If in the middle
if (to_remove->prev && to_remove->next) {
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // If tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // If head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // If only element
head_ = nullptr;
tail_ = nullptr;
}
to_remove->prev = nullptr;
to_remove->next = nullptr;
}
} // namespace
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}
Buffer MetalAllocator::malloc(size_t size) {
  // Align the requested size up to the page size
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
}
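  // e.g. with 16 KiB pages (typical on Apple silicon), a 20 KiB request is
  // rounded up to 32 KiB.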
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
// Prepare to allocate new memory as needed
if (!buf) {
    // If we are under very high memory pressure, do not allocate further
if (device_->currentAllocatedSize() >= block_limit_) {
return Buffer{nullptr};
}
// If we are still under memory pressure, try cleaning cache
if (buffer_cache_.can_garbage_collect()) {
buffer_cache_.release_cached_buffers(size);
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
buf = device_->newBuffer(size, res_opt);
}
peak_allocated_size_ =
std::max(peak_allocated_size_, device_->currentAllocatedSize());
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
buffer_cache_.recycle_to_cache(buf);
}
MetalAllocator& allocator() {
static MetalAllocator allocator_;
return allocator_;
}
} // namespace metal
} // namespace mlx::core

@ -0,0 +1,76 @@
#pragma once
#include <map>
#include <mutex>
#include <vector>
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core::metal {
using allocator::Buffer;
namespace {
class BufferCache {
public:
BufferCache(MTL::Device* device);
~BufferCache();
void clear();
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
size_t release_cached_buffers(size_t min_bytes_to_free);
bool can_garbage_collect() {
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
}
private:
struct BufferHolder {
public:
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
BufferHolder* prev;
BufferHolder* next;
MTL::Buffer* buf;
};
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
std::mutex cache_mutex_;
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
size_t gc_limit_;
};
} // namespace
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
private:
MTL::Device* device_;
MetalAllocator();
friend MetalAllocator& allocator();
// Caching allocator
BufferCache buffer_cache_;
// Allocation stats
size_t peak_allocated_size_;
size_t block_limit_;
};
MetalAllocator& allocator();
} // namespace mlx::core::metal

mlx/backend/metal/conv.cpp
@ -0,0 +1,555 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
void explicit_gemm_conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<1>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[1],
in_padded.strides()[2]};
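  // This is im2col expressed as a zero-copy strided view: dim 1 walks output
  // positions (stepping str[0] spatial positions of the padded input) and
  // dim 2 walks the filter window (stepping one spatial position), so
  // overlapping windows alias the same padded memory.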
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
  // Perform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(2),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1)},
/* const int wS[NDIM] = */ {wt.shape(1)},
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int dil[NDIM] = */ {wt_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
};
// Direct to explicit gemm conv
if (wt_dilation[0] == 1) {
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
}
}
void slow_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 16, bn = 8;
int tm = 4, tn = 4;
std::ostringstream kname;
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
size_t grid_dim_z = conv_params.N;
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
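  // Convolution as GEMM: every output pixel is a row (M = N * oS[0] * oS[1]),
  // every output channel a column (N = O), and the flattened receptive field
  // is the reduction dimension (K = wS[0] * wS[1] * C).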
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void explicit_gemm_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N,
conv_params.oS[0],
conv_params.oS[1],
conv_params.wS[0],
conv_params.wS[1],
conv_params.C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[2] * conv_params.str[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
  // Perform gemm
std::vector<array> copies = {in_padded, in_strided};
mlx_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
}
void winograd_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
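  // Round the padded spatial dims up so they split exactly into 8x8 input
  // tiles that overlap by 2 (each yielding a 6x6 output tile), e.g. a padded
  // size of 34 becomes 6 * ceil((34 - 2) / 6) + 2 = 38.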
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded);
copies_w.push_back(zero_arr);
MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0),
/* const int C = */ in_padded.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int dil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
};
int O_c = conv_params.O;
int C_c = conv_params.C;
int N_tiles_n = conv_params.N;
int N_tiles_h = (conv_params.oS[0] + 5) / 6;
int N_tiles_w = (conv_params.oS[1] + 5) / 6;
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
// Do filter transform
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
int bc = 32;
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, wt, 0);
set_array_buffer(compute_encoder, filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in_padded, 0);
set_array_buffer(compute_encoder, inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
mlx_matmul(
s,
d,
/*a = */ inp_wg,
/*b = */ filt_wg,
/*c = */ out_wg,
/*M = */ N_tiles,
/*N = */ conv_params.O,
/*K = */ conv_params.C,
/*batch_size_out = */ 8 * 8,
/*a_cols = */ conv_params.C,
/*b_cols = */ conv_params.O,
/*a_transposed = */ false,
/*b_transposed = */ false,
/*copies = */ empty_copies);
}
// Do output transform
{
int bc = 32;
int wm = 2;
int wn = 2;
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out_wg, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
};
// Direct to winograd conv
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
conv_params.dil[1] == 1) {
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure contiguity
std::vector<array> copies;
auto in = inputs[0];
auto wt = inputs[1];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
}
if (!wt.flags().row_contiguous) {
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
copy_gpu(wt, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
wt = arr_copy;
}
// 2D conv
if (out.ndim() == 4) {
conv_2D_gpu(
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
}
// 1D conv
else if (out.ndim() == 3) {
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
}
// Throw error
else {
throw std::invalid_argument(
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
}
// Clear copies
if (copies.size() > 0) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
}
} // namespace mlx::core

mlx/backend/metal/copy.h
@ -0,0 +1,16 @@
#pragma once
#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"
namespace mlx::core {
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
const array& src,
array& out,
CopyType ctype,
const Stream& s);
} // namespace mlx::core

@ -0,0 +1,30 @@
#include "mlx/backend/metal/kernels/bf16.h"
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] \
[[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)

@ -0,0 +1,208 @@
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename U>
struct IndexValPair {
uint32_t index;
U val;
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};
template <typename U>
struct ArgMin {
static constexpr constant U init = Limits<U>::max;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset+i;
}
}
return best;
}
};
template <typename U>
struct ArgMax {
static constexpr constant U init = Limits<U>::min;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset+i;
}
}
return best;
}
};
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
);
}
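// The overloads above exist because Metal's simd_shuffle_down is only defined
// for a fixed set of scalar types: bool and 64-bit integers are reinterpreted
// as supported types, and IndexValPair is shuffled member-wise.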
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
const device T *in [[buffer(0)]],
device uint32_t *out [[buffer(1)]],
const device int *shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Shapes and strides *do not* contain the reduction axis. The reduction axis
  // size and stride are provided in axis_size and axis_stride.
//
// Note: in shape == out shape with this convention.
//
// The sketch of the kernel is as follows.
  // 1. Launch prod(shape) * thread_group_size threads.
  // 2. Loop ceildiv(axis_size, lsize * N_READS) times
  // 3.   Read input values
  // 4.   Reduce among them and go to 3
  // 5. Reduce within each simd_group
  // 6. Write the per-simdgroup result to threadgroup memory
  // 7. Reduce across the thread group
  // 8. Write the output without the need for atomics
Op op;
  // Compute the input/output index. There is one input offset and one output
  // element for the whole threadgroup.
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
IndexValPair<T> best(0, Op::init);
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
    // Read the next N_READS values
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
uint32_t offset = current_index;
const device T * current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
for (int i=0; i<N_READS; i++) {
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
current_index++;
current_in += axis_stride;
}
best = op.template reduce_many<N_READS>(best, vals, offset);
}
  // At this point the axis has been reduced to one best value per thread, so
  // we reduce across the thread group, starting with a per-simdgroup
  // reduction.
for (uint offset=simd_size/2; offset>0; offset/=2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
// Write to the threadgroup memory
if (simd_lane_id == 0) {
local_data[simd_group_id] = best;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id != 0) {
return;
}
// Read the appropriate value from local data and perform one simd reduction
uint simd_groups = ceildiv(lsize, simd_size);
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
for (uint offset=simd_size/2; offset>0; offset/=2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
// Finally write the output
if (lid == 0) {
out[out_idx] = best.index;
}
}
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] \
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
const device itype *in [[buffer(0)]], \
device uint32_t * out [[buffer(1)]], \
const device int *shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)
instantiate_arg_reduce(uint16, uint16_t)
instantiate_arg_reduce(uint32, uint32_t)
instantiate_arg_reduce(uint64, uint64_t)
instantiate_arg_reduce(int8, int8_t)
instantiate_arg_reduce(int16, int16_t)
instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)

@ -0,0 +1,392 @@
#pragma once
#include "mlx/backend/metal/kernels/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
/*
Following the Metal Shading Language Specification (Metal 3.1)
"bfloat is an extended floating point type that only allows implicit conversion
to a type of greater floating point rank. While bfloat can be implicitly
converted to float, it cannot be implicitly converted to half, and neither
float nor half can be implicitly converted to bfloat."
Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling them with an argument of type bfloat will result in that
argument getting implicitly converted to float, which then returns an output
that is (likely) a float and cannot be implicitly converted back into a bfloat.
This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs returns a float
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
For the moment, I will be adding overloaded instantiations of the math
functions to automatically handle the casting accordingly.
*/
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
\
METAL_FUNC otype abs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acos(itype x) { \
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acosh(itype x) { \
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asin(itype x) { \
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asinh(itype x) { \
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atan(itype y_over_x) { \
return static_cast<otype>( \
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
} \
METAL_FUNC otype atan2(itype y, itype x) { \
return static_cast<otype>( \
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atanh(itype x) { \
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype ceil(itype x) { \
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cos(itype x) { \
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cosh(itype x) { \
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cospi(itype x) { \
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype divide(itype x, itype y) { \
return static_cast<otype>( \
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype exp(itype x) { \
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp10(itype x) { \
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp2(itype x) { \
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fabs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fdim(itype x, itype y) { \
ctype t = static_cast<ctype>(x - y); \
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
} \
METAL_FUNC otype floor(itype x) { \
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fma(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fma( \
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
} \
METAL_FUNC otype fmax(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmin(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmod(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fract(itype x) { \
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype frexp(itype x, thread int& exp) { \
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
} \
METAL_FUNC otype ldexp(itype x, int k) { \
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
} \
METAL_FUNC otype log(itype x) { \
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log10(itype x) { \
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log2(itype x) { \
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype max(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype max3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype median3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype min(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype min3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype nextafter(itype x, itype y) { \
return static_cast<otype>( \
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
} \
METAL_FUNC otype pow(itype x, itype y) { \
return static_cast<otype>( \
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype powr(itype x, itype y) { \
return static_cast<otype>( \
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype rint(itype x) { \
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype round(itype x) { \
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype rsqrt(itype x) { \
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sin(itype x) { \
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinh(itype x) { \
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinpi(itype x) { \
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sqrt(itype x) { \
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tan(itype x) { \
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanh(itype x) { \
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanpi(itype x) { \
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype trunc(itype x) { \
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
}
namespace metal {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_MAYBE_FAST_MATH__);
namespace fast {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_FAST_MATH__);
} // namespace fast
namespace precise {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_PRECISE_MATH__);
} // namespace precise
} // namespace metal
///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////
#define instantiate_metal_simd_comm_funcs( \
itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
\
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
return ctype_to_otype( \
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
return ctype_to_otype( \
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
return ctype_to_otype( \
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
}
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
\
METAL_FUNC otype simd_max(itype data) { \
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_min(itype data) { \
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_product(itype data) { \
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_sum(itype data) { \
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_xor(itype data) { \
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined(__HAVE_BFLOAT__)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(
bfloat16_t,
bfloat16_t,
uint16_t,
bfloat16_to_uint16,
uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
} // namespace metal

@ -0,0 +1,553 @@
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/gemm/conv.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
out += tid.z * params.out_strides[0];
in += tid.z * params.in_strides[0];
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for(int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for(int h = 0; h < params.wS[0]; ++h) {
for(int w = 0; w < params.wS[1]; ++w) {
for(int c = 0; c < params.C; ++c) {
        // Load input
for(int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
h * params.wt_strides[1] +
w * params.wt_strides[2] + c] : T(0);
}
// Accumulate
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
in, wt, out,
params, tgp_memory,
tid, lid, simd_gid, simd_lid
);
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
template <int M, int R, int S>
struct WinogradTransforms {
};
template <>
struct WinogradTransforms<6, 3, 8> {
MLX_MTL_CONST int OUT_TILE_SIZE = 6;
MLX_MTL_CONST int FILTER_SIZE = 3;
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
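  // Transform matrices for the F(6x6, 3x3) Winograd algorithm: 6x6 output
  // tiles, 3x3 filters, and 8x8 (= 6 + 3 - 1) input tiles matching the
  // simdgroup matrix size.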
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00, 0.00, 0.00},
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
{ 0.00, 0.00, 1.00},
};
};
constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
template <typename T,
int BC = 32,
int BO = 4,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
const device T* wt_in [[buffer(0)]],
device T* wt_out [[buffer(1)]],
const constant int& C [[buffer(2)]],
const constant int& O [[buffer(3)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
using WGT = WinogradTransforms<M, R, 8>;
// Get lane position in simdgroup
const short qid = simd_lane_id / 4;
const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
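  // Each lane of the simdgroup owns the two adjacent elements (sm, sn) and
  // (sm, sn + 1) of an 8x8 simdgroup matrix; sm/sn recover that position from
  // the lane id.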
// Initialize G matrix
simdgroup_matrix<T, 8, 8> G;
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
// Initialize Gt matrix
simdgroup_matrix<T, 8, 8> Gt;
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
// Move to the correct output filter
size_t ko = BO * tid + simd_group_id;
wt_in += ko * R * R * C;
// wt_out is stored transposed (A x A x C x O)
short ohw_0 = sm * 8 + sn;
short ohw_1 = sm * 8 + sn + 1;
device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
// Prepare shared memory
threadgroup T Ws[BO][R][R][BC];
// Loop over C
for(int bc = 0; bc < C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int kh = 0; kh < R; ++kh) {
for(int kw = 0; kw < R; ++kw) {
for(int kc = simd_lane_id; kc < BC; kc += 32) {
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
wt_out_1[c * O] = g_out.thread_elements()[1];
}
wt_in += BC;
wt_out_0 += BC * O;
wt_out_1 += BC * O;
}
}
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\
[[kernel]] void winograd_conv_2d_weight_transform<itype, bc>(\
const device itype* wt_in [[buffer(0)]],\
device itype* wt_out [[buffer(1)]],\
const constant int& C [[buffer(2)]],\
const constant int& O [[buffer(3)]],\
uint tid [[threadgroup_position_in_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint simd_lane_id [[thread_index_in_simdgroup]]);
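The (sm, sn) decode above (repeated in the input and output transform kernels below) gives each of the 32 lanes of a simdgroup the pair of adjacent elements (sm, sn) and (sm, sn + 1) of an 8x8 simdgroup_matrix; that layout is an assumption about how Metal distributes thread_elements(). A minimal host-side sketch, in Python, that only checks the decode tiles the 8x8 grid with no gaps or overlaps:

covered = []
for lane in range(32):
    qid = lane // 4
    sm = (qid & 4) + (lane // 2) % 4
    sn = (qid & 2) * 2 + (lane % 2) * 2
    covered += [(sm, sn), (sm, sn + 1)]
assert sorted(covered) == [(r, c) for r in range(8) for c in range(8)]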
template <typename T,
int BC,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
const device T* inp_in [[buffer(0)]],
device T* inp_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
constexpr int A = WGT::IN_TILE_SIZE;
constexpr int N_SIMD_GROUPS = WM * WN;
// Get lane position in simdgroup
const short qid = simd_lane_id / 4;
const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize B matrix
simdgroup_matrix<T, 8, 8> B;
B.thread_elements()[0] = WGT::in_transform[sm][sn];
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
// Initialize Bt matrix
simdgroup_matrix<T, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
// Resolve input tile
constexpr int TH = (A / WM);
constexpr int TW = (A / WN);
int kh = TH * (simd_group_id / WN);
int kw = TW * (simd_group_id % WN);
int bh = M * tid.y + kh;
int bw = M * tid.x + kw;
// Move to the correct input tile
inp_in += tid.z * params.in_strides[0]
+ bh * params.in_strides[1]
+ bw * params.in_strides[2];
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
}
}
// inp_out is stored interleaved (A x A x tiles x C)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
// Prepare shared memory
threadgroup T Is[A][A][BC];
// Loop over C
for(int bc = 0; bc < params.C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for(int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = I_out.thread_elements()[0];
inp_out_1[c] = I_out.thread_elements()[1];
}
inp_in += BC;
inp_out_0 += BC;
inp_out_1 += BC;
}
}
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\
[[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>(\
const device itype* inp_in [[buffer(0)]],\
device itype* inp_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BO,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
const device T* out_in [[buffer(0)]],
device T* out_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
constexpr int N_SIMD_GROUPS = WM * WN;
// Get lane position in simdgroup
const short qid = simd_lane_id / 4;
const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize the output transform matrix A (stored in variable B)
simdgroup_matrix<T, 8, 8> B;
B.thread_elements()[0] = WGT::out_transform[sm][sn];
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
// Initialize its transpose At (stored in variable Bt)
simdgroup_matrix<T, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
// Out_in comes in shape (A x A x tiles x O)
// We do transform and then write out to out_out in shape (N, H, W, O)
// Resolve output tile
constexpr int TH = (M / WM);
constexpr int TW = (M / WN);
int kh = TH * (simd_group_id / WN);
int kw = TW * (simd_group_id % WN);
int bh = M * tid.y + kh;
int bw = M * tid.x + kw;
// Move to the correct output tile
out_out += tid.z * params.out_strides[0]
+ bh * params.out_strides[1]
+ bw * params.out_strides[2];
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
}
}
// out_in is stored interleaved (A x A x tiles x O)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
// Prepare shared memory
threadgroup T Os[M][M][BO];
// Loop over O
for(int bo = 0; bo < params.O; bo += BO) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
if((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
}
if((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read out from shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
if(jump_in[h][w] >= 0) {
device T* out_ptr = out_out + jump_in[h][w];
for(int c = simd_lane_id; c < BO; c += 32) {
out_ptr[c] = Os[kh + h][kw + w][c];
}
}
}
}
out_out += BO;
out_in_0 += BO;
out_in_1 += BO;
}
}
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
const device itype* out_in [[buffer(0)]],\
device itype* out_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
uint simd_lane_id [[thread_index_in_simdgroup]]);
#define instantiate_winograd_conv_2d(name, itype) \
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
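Per channel, the three kernels above (together with a batched matmul over channels that lives elsewhere) implement the standard F(6x6, 3x3) Winograd identity Y = A^T [(G g G^T) * (B^T d B)] A, with an elementwise product inside the brackets, B = in_transform, G = wt_transform (8x3 once the zero padding is dropped) and A = out_transform (8x6). A NumPy sketch, not part of this commit, that checks the identity against a direct valid correlation of one 8x8 tile with one 3x3 filter:

import numpy as np

B = np.array([  # in_transform (8 x 8)
    [ 1.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
    [ 0.  ,  1.  , -1.  ,  0.5 , -0.5 ,  2.  , -2.  , -1.  ],
    [-5.25,  1.  ,  1.  ,  0.25,  0.25,  4.  ,  4.  ,  0.  ],
    [ 0.  , -4.25,  4.25, -2.5 ,  2.5 , -2.5 ,  2.5 ,  5.25],
    [ 5.25, -4.25, -4.25, -1.25, -1.25, -5.  , -5.  ,  0.  ],
    [ 0.  ,  1.  , -1.  ,  2.  , -2.  ,  0.5 , -0.5 , -5.25],
    [-1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  0.  ],
    [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  1.  ]])
A = np.array([  # out_transform (8 x 6; the zero padding to 8 x 8 is dropped)
    [1,  0.0, 0.00,  0.000,  0.0000,   0.00000],
    [1,  1.0, 1.00,  1.000,  1.0000,   1.00000],
    [1, -1.0, 1.00, -1.000,  1.0000,  -1.00000],
    [1,  2.0, 4.00,  8.000, 16.0000,  32.00000],
    [1, -2.0, 4.00, -8.000, 16.0000, -32.00000],
    [1,  0.5, 0.25,  0.125,  0.0625,   0.03125],
    [1, -0.5, 0.25, -0.125,  0.0625,  -0.03125],
    [0,  0.0, 0.00,  0.000,  0.0000,   1.00000]])
G = np.array([  # wt_transform (8 x 3; likewise without the zero padding)
    [1.0, 0.0, 0.0],
    [-2 / 9, -2 / 9, -2 / 9],
    [-2 / 9,  2 / 9, -2 / 9],
    [1 / 90,  1 / 45, 2 / 45],
    [1 / 90, -1 / 45, 2 / 45],
    [32 / 45,  16 / 45, 8 / 45],
    [32 / 45, -16 / 45, 8 / 45],
    [0.0, 0.0, 1.0]])

d = np.random.randn(8, 8)  # one 8 x 8 input tile, single channel
g = np.random.randn(3, 3)  # one 3 x 3 filter

Y = A.T @ ((G @ g @ G.T) * (B.T @ d @ B)) @ A  # 6 x 6 output tile
ref = np.array([[np.sum(d[i:i + 3, j:j + 3] * g)  # direct valid correlation
                 for j in range(6)] for i in range(6)])
assert np.allclose(Y, ref)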

View File

@ -0,0 +1,269 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src,
device U* dst,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src,
device U* dst,
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src,
device U* dst,
constant const size_t& src_stride,
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src,
device U* dst,
constant const size_t src_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src,
device U* dst,
constant const size_t src_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src,
device U* dst,
constant const int src_shape[DIM],
constant const size_t src_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src,
device U* dst,
constant const int* src_shape,
constant const size_t* src_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src,
device U* dst,
constant const size_t& src_stride,
constant const size_t& dst_stride,
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src,
device U* dst,
constant const size_t src_strides[2],
constant const size_t dst_strides[2],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src,
device U* dst,
constant const size_t src_strides[3],
constant const size_t dst_strides[3],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src,
device U* dst,
constant const int src_shape[DIM],
constant const size_t src_strides[DIM],
constant const size_t dst_strides[DIM],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src,
device U* dst,
constant const int* src_shape,
constant const size_t* src_strides,
constant const size_t* dst_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src, \
device otype* dst, \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
device const itype* src, \
device otype* dst, \
constant const int src_shape[dims], \
constant const size_t src_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
device const itype* src, \
device otype* dst, \
constant const int src_shape[dims], \
constant const size_t src_strides[dims], \
constant const size_t dst_strides[dims], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t& src_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
[[kernel]] void copy_gg_nd1<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t& src_stride, \
constant const size_t& dst_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
[[kernel]] void copy_gg_nd2<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[2], \
constant const size_t dst_strides[2], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
[[kernel]] void copy_gg_nd3<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[3], \
constant const size_t dst_strides[3], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const int* src_shape, \
constant const size_t* src_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const int* src_shape, \
constant const size_t* src_strides, \
constant const size_t* dst_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)
instantiate_copy_itype(uint16, uint16_t)
instantiate_copy_itype(uint32, uint32_t)
instantiate_copy_itype(uint64, uint64_t)
instantiate_copy_itype(int8, int8_t)
instantiate_copy_itype(int16, int16_t)
instantiate_copy_itype(int32, int32_t)
instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
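The general copy kernels gather from an arbitrarily strided source while writing the destination contiguously in row-major order. A host-side sketch, illustration only; elem_to_loc below is an assumed Python analogue of the helper declared in kernels/utils.h, not that helper itself:

import numpy as np

def elem_to_loc(flat_idx, shape, strides):
    # Decompose a row-major flat index (last axis fastest) and dot with strides.
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (flat_idx % dim) * stride
        flat_idx //= dim
    return loc

def copy_g_reference(src_flat, shape, strides):
    # What copy_g computes: dst[i] = src[elem_to_loc(i, shape, strides)].
    n = int(np.prod(shape))
    dst = np.empty(n, dtype=src_flat.dtype)
    for i in range(n):
        dst[i] = src_flat[elem_to_loc(i, shape, strides)]
    return dst.reshape(shape)

# Example: copy a transposed (non-contiguous) view back into contiguous memory.
a = np.arange(12).reshape(3, 4)
t = a.T                                  # shape (4, 3), element strides (1, 4)
out = copy_g_reference(a.reshape(-1), t.shape,
                       [s // a.itemsize for s in t.strides])
assert np.array_equal(out, np.ascontiguousarray(t))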

View File

@ -0,0 +1,68 @@
#pragma once
#include <metal_math>
/*
* Approximation to the error function.
* Based on code from:
* https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
*/
float erf(float a) {
float r, s, t, u;
t = metal::abs(a);
s = a * a;
if (t > 0.927734375f) {
// maximum error 0.99527 ulp
r = metal::fma(
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
u = metal::fma(
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
r = metal::fma(r, s, u);
r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
r = metal::fma(r, t, -t);
// TODO, replace with expm1 when implemented
r = 1.0f - metal::exp(r);
r = metal::copysign(r, a);
} else {
// maximum error 0.98929 ulp
r = -5.96761703e-4f; // -0x1.38e000p-11
r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
r = metal::fma(r, a, a);
}
return r;
}
float erfinv(float a) {
auto t = metal::fma(a, 0.0f - a, 1.0f);
t = metal::log(t);
float p;
if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
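A double-precision sanity check, not part of the commit, that the single-precision polynomial above tracks the reference error function; the fma chains are flattened to ordinary multiply-adds, which is harmless at double precision, and math.erf is the only oracle used:

import math

def erf_approx(a):
    t, s = abs(a), a * a
    if t > 0.927734375:
        r = -1.72853470e-5 * t + 3.83197126e-4
        u = -3.88396438e-3 * t + 2.42546219e-2
        r = r * s + u
        for c in (-1.06777877e-1, -6.34846687e-1, -1.28717512e-1):
            r = r * t + c
        r = r * t - t
        r = math.copysign(1.0 - math.exp(r), a)
    else:
        r = -5.96761703e-4
        for c in (4.99119423e-3, -2.67681349e-2, 1.12819925e-1,
                  -3.76125336e-1, 1.28379166e-1):
            r = r * s + c
        r = r * a + a
    return r

xs = [x / 1000.0 for x in range(-4000, 4001)]
assert max(abs(erf_approx(x) - math.erf(x)) for x in xs) < 1e-6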

View File

@ -0,0 +1,91 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant int &M [[buffer(3)]],
const constant int &N [[buffer(4)]],
const constant int &K [[buffer(5)]],
const constant int &batch_stride_a [[buffer(6)]],
const constant int &batch_stride_b [[buffer(7)]],
const constant int &batch_stride_c [[buffer(8)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
A, B, C,
M, N, K,
batch_stride_a, batch_stride_b, batch_stride_c,
tgp_memory,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant int &M [[buffer(3)]], \
const constant int &N [[buffer(4)]], \
const constant int &K [[buffer(5)]], \
const constant int &batch_stride_a [[buffer(6)]], \
const constant int &batch_stride_b [[buffer(7)]], \
const constant int &batch_stride_c [[buffer(8)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
// TODO: Accumulation in different type
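The instantiation macros above encode every template configuration into the host_name. A small sketch, for illustration only, of how a host-side lookup could rebuild such a name; the dispatch code that actually does this is not in this file, and the input and output type names are identical in the instantiations above:

def gemm_kernel_name(trans_a, trans_b, dtype, bm, bn, bk, wm, wn,
                     mn_aligned, k_aligned):
    tname = ("t" if trans_a else "n") + ("t" if trans_b else "n")
    aname = "taligned" if mn_aligned else "naligned"
    kname = "taligned" if k_aligned else "naligned"
    return (f"gemm_{tname}_{dtype}_{dtype}_bm{bm}_bn{bn}_bk{bk}"
            f"_wm{wm}_wn{wn}_MN_{aname}_K_{kname}")

# A 64x64x16 float32 kernel with B transposed, aligned M/N and a ragged K:
assert (gemm_kernel_name(False, True, "float32", 64, 64, 16, 2, 2, True, False)
        == "gemm_nt_float32_float32_bm64_bn64_bk16_wm2_wn2_MN_taligned_K_naligned")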

View File

@ -0,0 +1,99 @@
#include "mlx/backend/metal/kernels/utils.h"
static constexpr constant uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}
};
union rbits {
uint2 val;
uchar4 bytes[2];
};
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
rbits v;
v.val.x = count.x + ks[0];
v.val.y = count.y + ks[1];
for (int i = 0; i < 5; ++i) {
for (auto r : rotations[i % 2]) {
v.val.x += v.val.y;
v.val.y = (v.val.y << r) | (v.val.y >> (32 - r));
v.val.y ^= v.val.x;
}
v.val.x += ks[(i + 1) % 3];
v.val.y += ks[(i + 2) % 3] + i + 1;
}
return v;
}
[[kernel]] void rbitsc(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
auto key = uint2(keys[kidx], keys[kidx + 1]);
auto half_size = grid_dim.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
auto bits = threefry2x32_hash(key, count);
for (int i = 0; i < 4; ++i) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
}
}
}
[[kernel]] void rbits(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
device const int& ndim,
device const int* key_shape,
device const size_t* key_strides,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim);
auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
auto key = uint2(keys[k1_elem], keys[k2_elem]);
auto half_size = grid_dim.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
auto bits = threefry2x32_hash(key, count);
for (int i = 0; i < 4; ++i) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
}
}
}
}
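A pure-Python mirror, for illustration only, of threefry2x32_hash above; it follows the same key schedule, rotation table and round structure and can be used to spot-check kernel output on the host:

MASK = 0xFFFFFFFF
ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))

def threefry2x32(key, count):
    kx, ky = key[0] & MASK, key[1] & MASK
    ks = (kx, ky, kx ^ ky ^ 0x1BD11BDA)
    x = (count[0] + ks[0]) & MASK
    y = (count[1] + ks[1]) & MASK
    for i in range(5):
        for r in ROTATIONS[i % 2]:
            x = (x + y) & MASK
            y = ((y << r) | (y >> (32 - r))) & MASK
            y ^= x
        x = (x + ks[(i + 1) % 3]) & MASK
        y = (y + ks[(i + 2) % 3] + i + 1) & MASK
    return x, y

# Each (key, count) pair yields 8 pseudo-random bytes, written out 4 at a time
# by rbitsc/rbits above.
x0, x1 = threefry2x32((0x12345678, 0x9ABCDEF0), (0, 1))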

View File

@ -0,0 +1,536 @@
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
static constant uint8_t simd_size = 32;
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(0)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// NB: this kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
Op op;
threadgroup U local_vals[simd_size];
U total_val = Op::init;
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// General reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device int *in_shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto in_idx = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
auto out_idx = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);
op.atomic_update(out, static_cast<U>(in[in_idx]), out_idx);
}
#define instantiate_general_reduce_helper(name, itype, otype, op) \
template [[host_name("general_reduce_" #name)]] \
[[kernel]] void general_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
template [[host_name("general_reduce_" #name "_dim_" #n)]] \
[[kernel]] void general_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device int *in_shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
uint gid [[thread_position_in_grid]]);
#define instantiate_general_reduce(name, itype, otype, op) \
instantiate_general_reduce_helper(name, itype, otype, op) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& reduction_size [[buffer(2)]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
in += tid * reduction_size + lid * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid == 0) {
out[tid] = total_val;
}
}
#define instantiate_row_reduce(name, itype, otype, op) \
template [[host_name("row_reduce_" #name)]] \
[[kernel]] void row_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& reduction_size [[buffer(2)]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
const device T *in,
device mlx_atomic<U> *out,
threadgroup U *local_data,
uint in_idx,
uint out_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
T local_vals[N_READS];
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
if(lid.y == 0) {
U val = op.init;
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
op.atomic_update(out, val, out_idx);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
out_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_col_reduce(name, itype, otype, op) \
template [[host_name("col_reduce_" #name)]] \
[[kernel]] void col_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
template <typename T, typename U, typename Op, int NDIM, int N_READS = 16>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const device int* in_shape [[buffer(5)]],
const device size_t* in_strides [[buffer(6)]],
const device size_t& in_dim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint2 tid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);
if(out_idx < out_size) {
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid,
lid,
lsize);
}
}
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
template [[host_name("contiguous_strided_reduce_" #name)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
const device size_t& in_dim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
[[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const device int* in_shape [[buffer(5)]], \
const device size_t* in_strides [[buffer(6)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint2 tid [[threadgroup_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]]);
#define instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_contiguous_strided_helper(name, itype, otype, op) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_reduce(name, itype, otype, op) \
instantiate_all_reduce(name, itype, otype, op) \
instantiate_row_reduce(name, itype, otype, op) \
instantiate_col_reduce(name, itype, otype, op) \
instantiate_contiguous_strided(name, itype, otype, op) \
instantiate_general_reduce(name, itype, otype, op)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
instantiate_reduce(name ##tname, itype, otype, op)
#define instantiate_reduce_from_types(name, otype, op) \
instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)
// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)
// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
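all_reduce above accumulates in stages: each thread sums a strided sweep of N_READS-sized blocks, simd_reduce folds the 32 lanes of a simdgroup, a second simd_reduce folds the per-simdgroup partials held in threadgroup memory, and an atomic update combines threadgroups. A NumPy sketch of that work partitioning, illustration only, assuming Sum and a single threadgroup dispatch:

import numpy as np

def simulate_all_reduce(x, threads=1024, n_reads=16, simd=32):
    # Per-thread partials over the same strided sweep the kernel performs:
    # thread t reads blocks of n_reads starting at t * n_reads, striding by
    # threads * n_reads; out-of-range reads are masked (Op::init == 0 for Sum).
    partial = np.zeros(threads)
    for t in range(threads):
        starts = np.arange(t * n_reads, x.size, threads * n_reads)
        idx = (starts[:, None] + np.arange(n_reads)).ravel()
        partial[t] = x[idx[idx < x.size]].sum()
    # simd_reduce within each simdgroup, then across the simdgroups of the
    # threadgroup; with one threadgroup this is also the final atomic value.
    return partial.reshape(-1, simd).sum(axis=1).sum()

x = np.random.randn(100_003)
assert abs(simulate_all_reduce(x) - x.sum()) < 1e-6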

View File

@ -0,0 +1,284 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/bf16.h"
struct Abs {
template <typename T> T operator()(T x) { return metal::abs(x); };
template <> uint8_t operator()(uint8_t x) { return x; };
template <> uint16_t operator()(uint16_t x) { return x; };
template <> uint32_t operator()(uint32_t x) { return x; };
template <> uint64_t operator()(uint64_t x) { return x; };
template <> bool operator()(bool x) { return x; };
template <> complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T> T operator()(T x) { return metal::precise::acos(x); };
};
struct ArcCosh {
template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
};
struct ArcSin {
template <typename T> T operator()(T x) { return metal::precise::asin(x); };
};
struct ArcSinh {
template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
};
struct ArcTan {
template <typename T> T operator()(T x) { return metal::precise::atan(x); };
};
struct ArcTanh {
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
};
struct Cos {
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Cosh {
template <typename T> T operator()(T x) { return metal::precise::cosh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Erf {
template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
};
struct ErfInv {
template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
};
struct Exp {
template <typename T> T operator()(T x) { return metal::precise::exp(x); };
template <> complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Log {
template <typename T> T operator()(T x) { return metal::precise::log(x); };
};
struct Log2 {
template <typename T> T operator()(T x) { return metal::precise::log2(x); };
};
struct Log10 {
template <typename T> T operator()(T x) { return metal::precise::log10(x); };
};
struct Log1p {
template <typename T> T operator()(T x) { return log1p(x); };
};
struct LogicalNot {
template <typename T> T operator()(T x) { return !x; };
};
struct Negative {
template <typename T> T operator()(T x) { return -x; };
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
template <> uint32_t operator()(uint32_t x) { return x != 0; };
};
struct Sin {
template <typename T> T operator()(T x) { return metal::precise::sin(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
};
};
};
struct Sinh {
template <typename T> T operator()(T x) { return metal::precise::sinh(x); };
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
};
};
};
struct Square {
template <typename T> T operator()(T x) { return x * x; };
};
struct Sqrt {
template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
};
struct Rsqrt {
template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
};
struct Tan {
template <typename T> T operator()(T x) { return metal::precise::tan(x); };
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {
(tan_a - tanh_b * t1) / denom,
(tanh_b + tan_a * t1) / denom
};
};
};
struct Tanh {
template <typename T> T operator()(T x) { return metal::precise::tanh(x); };
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {
(tanh_a + tan_b * t1) / denom,
(tan_b - tanh_a * t1) / denom
};
};
};
template <typename T, typename Op>
[[kernel]] void unary_op_v(
device const T* in,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]);
}
template <typename T, typename Op>
[[kernel]] void unary_op_g(
device const T* in,
device T* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]) {
auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
out[index] = Op()(in[idx]);
}
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_v<type, op>( \
device const type* in, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_all(name, tname, type, op) \
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op)
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op)
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
instantiate_unary_all(name, uint16, uint16_t, op) \
instantiate_unary_all(name, uint32, uint32_t, op) \
instantiate_unary_all(name, uint64, uint64_t, op) \
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op)
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)
instantiate_unary_float(log1p, Log1p)
instantiate_unary_types(neg, Negative)
instantiate_unary_float(sigmoid, Sigmoid)
instantiate_unary_float(erf, Erf)
instantiate_unary_float(erfinv, ErfInv)
instantiate_unary_types(sign, Sign)
instantiate_unary_float(sin, Sin)
instantiate_unary_float(sinh, Sinh)
instantiate_unary_types(square, Square)
instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
instantiate_unary_all(neg, complex64, complex64_t, Negative)
instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
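The complex specializations above use the usual rewrites in terms of real sin/cos/sinh/cosh/tan/tanh; for Tan, tan(a + ib) = (tan a + i tanh b) / (1 - i tan a tanh b) expanded with t1 = tan a * tanh b. A quick check, not part of the commit, against the standard library:

import cmath
import math
import random

def tan_complex(a, b):
    # Mirrors the Tan<complex64_t> specialization above, in double precision.
    tan_a, tanh_b = math.tan(a), math.tanh(b)
    t1 = tan_a * tanh_b
    denom = 1.0 + t1 * t1
    return complex((tan_a - tanh_b * t1) / denom,
                   (tanh_b + tan_a * t1) / denom)

for _ in range(1000):
    a, b = random.uniform(-1.5, 1.5), random.uniform(-1.5, 1.5)
    assert abs(tan_complex(a, b) - cmath.tan(complex(a, b))) < 1e-9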

View File

@ -0,0 +1,369 @@
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
//////////////////////////////////////////////////////////////////////
// Case wise reduce dispatch
//////////////////////////////////////////////////////////////////////
namespace {
// All Reduce
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
// Get kernel and encode buffers
size_t in_size = in.size();
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
// at least n_reads inputs
int n_reads = REDUCE_N_READS;
// mod_in_size is the number of n_reads-sized chunks needed to cover the
// entire input
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
// If the number of thread groups needed exceeds 1024, we reuse thread groups
uint n_thread_groups =
(mod_in_size + thread_group_size - 1) / thread_group_size;
n_thread_groups = std::min(n_thread_groups, 1024u);
uint nthreads = n_thread_groups * thread_group_size;
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
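A sketch of the launch-geometry arithmetic in all_reduce_dispatch above, for illustration only; REDUCE_N_READS is assumed to be 16 and the pipeline's maxTotalThreadsPerThreadgroup is assumed to be 1024:

def all_reduce_geometry(in_size, max_threads_per_group=1024, n_reads=16):
    mod_in_size = (in_size + n_reads - 1) // n_reads           # N_READS chunks
    group = min(mod_in_size, max_threads_per_group)            # threads per group
    n_groups = min((mod_in_size + group - 1) // group, 1024)   # capped at 1024
    return group, n_groups, n_groups * group                   # group dims, grid size

assert all_reduce_geometry(1_000_000) == (1024, 62, 63488)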
void row_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
int n_reads = REDUCE_N_READS;
size_t reduction_size = in.size() / out.size();
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void col_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
std::ostringstream kernel_name;
bool encode_in_shape = false;
bool encode_ndim = false;
// If the slowest moving axis can be merged into the reductions,
// we call the column reduce kernel
// In this case, a linear index in the output corresponds to the
// linear index in the input where the reduction starts
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
kernel_name << "col_reduce_" << op_name << type_to_name(in);
}
// Otherwise, while all the reduction axes can be merged, the mapping between
// indices in the output and input require resolving using shapes and strides
else {
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
encode_in_shape = true;
// We check for a viable template with the required number of dimensions
// we only care about encoding non-reduced shapes and strides in the input
size_t non_reducing_dims = in.ndim() - axes_.size();
if (non_reducing_dims >= 1 &&
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << non_reducing_dims;
} else {
encode_ndim = true;
}
}
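// Illustrative examples (assumed shapes): for in.shape = {8, 16, 32},
// axes_ = {0, 1} makes the last reduced axis equal to axes_.size() - 1, so
// the plain column reduce kernel is selected; axes_ = {1} instead needs the
// non-reduced shapes and strides, so contiguous_strided_reduce is selected
// (with "_dim_2" if MAX_REDUCE_SPECIALIZED_DIMS >= 2).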
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t out_size = out.size();
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
// Calculate the number of inputs to reduce and the stride between them
size_t reduction_size = 1;
size_t in_ndim = in.ndim();
size_t reduction_stride = in_size;
for (int i : axes_) {
reduction_size *= in.shape(i);
reduction_stride = std::min(reduction_stride, in.strides()[i]);
}
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
if (encode_in_shape) {
// Obtain the non-reducing shape and strides of the input to encode
std::vector<int> inp_shape_mod;
std::vector<size_t> inp_strides_mod;
for (size_t i = 0, j = 0; i < in.ndim(); i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
inp_shape_mod.push_back(in.shape(i));
inp_strides_mod.push_back(in.strides()[i]);
}
}
size_t ndim = inp_shape_mod.size();
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
}
}
// Select block dimensions
// Each thread reads n_inputs_per_thread inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
uint n_threads_per_output =
(reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread;
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
// output and each thread in the y dimension accumulates for that output
uint threadgroup_dim_x = std::min(out_size, 128ul);
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
uint n_threadgroups_x =
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
uint n_threadgroups_y =
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
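// Illustrative sizing (assuming REDUCE_N_READS == 16 and a 1024-thread
// maximum per threadgroup): for out_size = 512 and reduction_size = 4096,
// n_threads_per_output = 256, threadgroup_dim_x = 128 and
// threadgroup_dim_y = min(256, 1024 / 128) = 8, giving a 4 x 32 grid of
// 128 x 8 threadgroups.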
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: each threadgroup should have 32 kB of threadgroup memory available
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design.
// This should be fine for floats, but we might need to revisit it if we
// ever come to doubles. In that case, we should also cut down the number
// of threads we launch in a threadgroup.
compute_encoder->setThreadgroupMemoryLength(
threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void general_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
const std::vector<int>& axes_,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d) {
bool encode_ndim = true;
std::ostringstream kernel_name;
kernel_name << "general_reduce_" << op_name << type_to_name(in);
// Check for specialized kernels for the input ndim
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
kernel_name << "_dim_" << in.ndim();
encode_ndim = false;
}
auto kernel = d.get_kernel(kernel_name.str());
size_t in_size = in.size();
size_t ndim = in.ndim();
// We set the reducing strides to 0 to induce collisions for the reduction
std::vector<size_t> out_strides(ndim);
size_t stride = 1;
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
if (j >= 0 && axes_[j] == i) {
out_strides[i] = 0;
--j;
} else {
out_strides[i] = stride;
stride *= in.shape(i);
}
}
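// Illustrative example (assumed shape): for in.shape = {2, 3, 4} and
// axes_ = {1}, out_strides becomes {4, 0, 1}; every element along axis 1
// maps to the same output location, so the kernel's accumulation collapses
// that axis.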
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
if (encode_ndim) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > in_size) {
thread_group_size = in_size;
}
size_t nthreads = in_size;
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
} // namespace
//////////////////////////////////////////////////////////////////////
// Main reduce dispatch
//////////////////////////////////////////////////////////////////////
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// TODO: Allow row and column reductions for the types currently disabled
// due to atomics?
if (size_of(in.dtype()) == 8) {
std::ostringstream msg;
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
throw std::runtime_error(msg.str());
}
// Make sure no identity reductions trickle down here
assert(!axes_.empty());
// Continue with reduction operation
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::string op_name;
switch (reduce_type_) {
case Reduce::And:
op_name = "and";
break;
case Reduce::Or:
op_name = "or";
break;
case Reduce::Sum:
op_name = "sum";
break;
case Reduce::Prod:
op_name = out.dtype() == bool_ ? "and" : "prod";
break;
case Reduce::Min:
op_name = out.dtype() == bool_ ? "and" : "min_";
break;
case Reduce::Max:
op_name = out.dtype() == bool_ ? "or" : "max_";
break;
}
// Initialize output
auto& s = stream();
auto& d = metal::device(s.device);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out, 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
// Reduce
{
// Check for contiguous data
if (in.size() == in.data_size() &&
(in.flags().row_contiguous || in.flags().col_contiguous)) {
// Go to all reduce if reducing over all axes
if (axes_.size() == in.ndim()) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
return;
}
// Use specialized kernels if the input is row contiguous and
// the reducing axes can be merged into one
else if (
in.flags().row_contiguous && in.strides().back() == 1 &&
(axes_.back() - axes_.front()) == axes_.size() - 1) {
// If the fastest moving axis is being reduced, go to row reduce
if (axes_[0] == (in.ndim() - axes_.size())) {
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
// Otherwise go to the generalized strided reduce
// Note: bool isn't supported here yet due to the use of atomics;
// once that is updated, this should be the else condition of this
// branch
else if (in.dtype() != bool_) {
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
return;
}
}
}
// Fall back to the general case
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
}
}
} // namespace mlx::core

View File

@ -0,0 +1,15 @@
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Allocator& allocator() {
static CommonAllocator allocator_;
return allocator_;
}
void* Buffer::raw_ptr() {
return ptr_;
}
} // namespace mlx::core::allocator

149
mlx/fft.h Normal file
View File

@ -0,0 +1,149 @@
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "stream.h"
namespace mlx::core::fft {
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
/** Compute the n-dimensional Fourier Transform. */
array fftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
array fftn(const array& a, StreamOrDevice s = {});
/** Compute the n-dimensional inverse Fourier Transform. */
array ifftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array ifftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
array ifftn(const array& a, StreamOrDevice s = {});
/** Compute the one-dimensional Fourier Transform. */
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return fftn(a, {n}, {axis}, s);
}
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return fftn(a, {axis}, s);
}
/** Compute the one-dimensional inverse Fourier Transform. */
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return ifftn(a, {n}, {axis}, s);
}
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return ifftn(a, {axis}, s);
}
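// Illustrative usage (a sketch): given a complex64 array x,
//   auto y = fft(x);   // 1-D transform over the last axis
//   auto z = ifft(y);  // recovers x up to floating point error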
/** Compute the two-dimensional Fourier Transform. */
inline array fft2(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return fftn(a, n, axes, s);
}
inline array fft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return fftn(a, axes, s);
}
/** Compute the two-dimensional inverse Fourier Transform. */
inline array ifft2(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return ifftn(a, n, axes, s);
}
inline array ifft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return ifftn(a, axes, s);
}
/** Compute the n-dimensional Fourier Transform on a real input. */
array rfftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array rfftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
array rfftn(const array& a, StreamOrDevice s = {});
/** Compute the n-dimensional inverse of `rfftn`. */
array irfftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array irfftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
array irfftn(const array& a, StreamOrDevice s = {});
/** Compute the one-dimensional Fourier Transform on a real input. */
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return rfftn(a, {n}, {axis}, s);
}
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return rfftn(a, {axis}, s);
}
/** Compute the one-dimensional inverse of `rfft`. */
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return irfftn(a, {n}, {axis}, s);
}
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return irfftn(a, {axis}, s);
}
/** Compute the two-dimensional Fourier Transform on a real input. */
inline array rfft2(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return rfftn(a, n, axes, s);
}
inline array rfft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return rfftn(a, axes, s);
}
/** Compute the two-dimensional inverse of `rfft2`. */
inline array irfft2(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return irfftn(a, n, axes, s);
}
inline array irfft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return irfftn(a, axes, s);
}
} // namespace mlx::core::fft

240
mlx/load.cpp Normal file
View File

@ -0,0 +1,240 @@
#include <algorithm>
#include <cstring>
#include <fstream>
#include <limits>
#include <sstream>
#include "mlx/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
// Adapted from
// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp
namespace mlx::core {
namespace {
static constexpr uint8_t MAGIC[] = {
0x93,
0x4e,
0x55,
0x4d,
0x50,
0x59,
};
inline bool is_big_endian_() {
union ByteOrder {
int32_t i;
uint8_t c[4];
};
ByteOrder b = {0x01234567};
return b.c[0] == 0x01;
}
} // namespace
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
////////////////////////////////////////////////////////
// Check array
a.eval(retain_graph);
if (a.nbytes() == 0) {
throw std::invalid_argument("[save] cannot serialize an empty array");
}
if (!a.flags().contiguous) {
throw std::invalid_argument(
"[save] cannot serialize a non-contiguous array");
}
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {
throw std::runtime_error("[save] Failed to open " + out_stream->label());
}
////////////////////////////////////////////////////////
// Prepare header
std::ostringstream magic_ver_len;
magic_ver_len.write(reinterpret_cast<const char*>(MAGIC), 6);
std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
std::ostringstream header;
header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
<< " 'fortran_order': " << fortran_order << ","
<< " 'shape': (";
for (auto i : a.shape()) {
header << i << ", ";
}
header << ")}";
size_t header_len = static_cast<size_t>(header.tellp());
bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max();
// Pad out magic + version + header_len + header + \n to be divisible by 16
size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16;
header << std::string(padding, ' ') << '\n';
if (is_v1) {
magic_ver_len << (char)0x01 << (char)0x00;
uint16_t v1_header_len = header.tellp();
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
if (!is_big_endian_()) {
magic_ver_len.write(len_bytes, 2);
} else {
magic_ver_len.write(len_bytes + 1, 1);
magic_ver_len.write(len_bytes, 1);
}
} else {
magic_ver_len << (char)0x02 << (char)0x00;
uint32_t v2_header_len = header.tellp();
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
if (!is_big_endian_()) {
magic_ver_len.write(len_bytes, 4);
} else {
magic_ver_len.write(len_bytes + 3, 1);
magic_ver_len.write(len_bytes + 2, 1);
magic_ver_len.write(len_bytes + 1, 1);
magic_ver_len.write(len_bytes, 1);
}
}
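// Resulting prefix layout (standard .npy format): the 6-byte magic
// "\x93NUMPY", a 2-byte version, a little-endian header length (2 bytes for
// v1, 4 bytes for v2), then the ASCII header dict, space-padded and
// terminated by '\n' so the raw array data that follows starts aligned.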
////////////////////////////////////////////////////////
// Serialize array
out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
out_stream->write(header.str().c_str(), header.str().length());
out_stream->write(a.data<char>(), a.nbytes());
return;
}
/** Save array to file in .npy format */
void save(const std::string& file_, array a, bool retain_graph) {
// Open and check file
std::string file = file_;
// Add .npy to file name if it is not there
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
file += ".npy";
// Serialize array
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
}
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
////////////////////////////////////////////////////////
// Open and check file
if (!in_stream->good() || !in_stream->is_open()) {
throw std::runtime_error("[load] Failed to open " + in_stream->label());
}
////////////////////////////////////////////////////////
// Read header and prepare array details
// Read and check magic
char read_magic_and_ver[8];
in_stream->read(read_magic_and_ver, 8);
if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {
throw std::runtime_error("[load] Invalid header in " + in_stream->label());
}
// Read and check version
if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
throw std::runtime_error(
"[load] Unsupport npy format version in " + in_stream->label());
}
// Read header len and header
int header_len_size = read_magic_and_ver[6] == 1 ? 2 : 4;
size_t header_len;
if (header_len_size == 2) {
uint16_t v1_header_len;
in_stream->read(reinterpret_cast<char*>(&v1_header_len), header_len_size);
header_len = v1_header_len;
} else {
uint32_t v2_header_len;
in_stream->read(reinterpret_cast<char*>(&v2_header_len), header_len_size);
header_len = v2_header_len;
}
// Read the header
std::vector<char> buffer(header_len + 1);
in_stream->read(&buffer[0], header_len);
buffer[header_len] = 0;
std::string header(&buffer[0]);
// Read data type from header
std::string dtype_str = header.substr(11, 3);
bool read_is_big_endian = dtype_str[0] == '>';
Dtype dtype = dtype_from_array_protocol(dtype_str);
// Read contiguity order
bool col_contiguous = header[34] == 'T';
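// Illustrative header (as written by save() above):
//   {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4, )}
// so the dtype string starts at character 11 and the 'T'/'F' of
// 'fortran_order' sits at character 34.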
// Read array shape from header
std::vector<int> shape;
size_t st = header.find_last_of('(') + 1;
size_t ed = header.find_last_of(')');
std::string shape_str = header.substr(st, ed - st);
while (!shape_str.empty()) {
// Read current number and get position of comma
size_t pos;
int dim = std::stoi(shape_str, &pos);
shape.push_back(dim);
// Skip the comma and space and read the next number
if (pos + 2 <= shape_str.length())
shape_str = shape_str.substr(pos + 2);
else {
shape_str = shape_str.substr(pos);
if (!shape_str.empty() && shape_str != " " && shape_str != ",") {
throw std::runtime_error(
"[load] Unknown error while parsing header in " +
in_stream->label());
}
shape_str = "";
}
}
////////////////////////////////////////////////////////
// Build primitive
size_t offset = 8 + header_len_size + header.length();
bool swap_endianness = read_is_big_endian != is_big_endian_();
if (col_contiguous) {
std::reverse(shape.begin(), shape.end());
}
auto loaded_array = array(
shape,
dtype,
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
std::vector<array>{});
if (col_contiguous) {
loaded_array = transpose(loaded_array, s);
}
return loaded_array;
}
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s) {
return load(std::make_shared<io::FileReader>(file), s);
}
} // namespace mlx::core

2265
mlx/primitives.cpp Normal file

File diff suppressed because it is too large

778
mlx/transforms.cpp Normal file
View File

@ -0,0 +1,778 @@
#include <algorithm>
#include <future>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/backend/metal/metal.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core {
void simplify(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
auto is_scalar = [](const array& a) {
return a.is_evaled() && a.ndim() == 0;
};
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;
switch (a.dtype().size) {
case 1:
v = *a.data<uint8_t>();
break;
case 2:
// 2-byte dtypes (float16 / bfloat16) also need a distinct representation
v = *a.data<uint16_t>();
break;
case 4:
v = *a.data<uint32_t>();
break;
case 8:
v = *a.data<uint64_t>();
break;
}
return std::make_pair(v, a.dtype().val);
};
// DFS the graph to log the parents
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (int i = 0; i < a.inputs().size(); i++) {
auto& in = a.inputs()[i];
parents_map[in.id()].push_back({a, i});
recurse(in);
}
cache.insert(id);
tape.push(a);
if (is_scalar(a)) {
scalars.insert({get_scalar_rep(a), a});
}
};
for (auto& a : outputs) {
recurse(a);
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
auto src_parents = parents_map.find(src.id());
if (src_parents == parents_map.end()) {
return;
}
auto& pairs = parents_map[dst.id()];
for (auto& parent : src_parents->second) {
parent.first.editable_inputs()[parent.second] = dst;
pairs.push_back(parent);
}
};
// Walk the graph
cache.clear();
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
return false;
}
const auto& pa = a.primitive();
const auto& pb = b.primitive();
if (typeid(pa) != typeid(pb)) {
return false;
}
if (a.inputs().size() != b.inputs().size()) {
return false;
}
for (int i = 0; i < a.inputs().size(); i++) {
if (a.inputs()[i].id() != b.inputs()[i].id()) {
return false;
}
}
return pa.is_equivalent(pb);
};
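// Illustrative example (a sketch): if b = add(x, y) and c = add(x, y) are
// built independently, array_equivalent(b, c) holds (same primitive type,
// same inputs, equivalent parameters), so one of them is fused into the
// other and the addition runs only once at eval time.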
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
if (cache.find(arr.id()) != cache.end()) {
continue;
}
// Check if we can fuse scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
arr = scalar->second;
}
}
// Check if we can fuse the parents of this array
auto parents = parents_map.find(arr.id());
if (parents != parents_map.end()) {
std::vector<bool> mask(parents->second.size(), false);
auto N = parents->second.size();
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
cache.insert(src.id());
fuse(dst, src);
mask[j] = true;
}
}
}
}
}
}
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
if (!retain_graph) {
for (auto& out : outputs) {
if (out.has_primitive() && out.is_tracer()) {
throw std::invalid_argument(
"[eval] Illegal to eval an array during "
"function transform without graph retention.");
}
}
}
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (auto in : a.inputs()) {
recurse(in);
// If one of the inputs is being computed on a different
// stream, we need to manage the dependency.
if (!in.is_evaled()) {
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.id(), std::shared_future<void>{}});
}
}
}
cache.insert(id);
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
if (!a.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
}
tape.push(a);
}
};
for (auto& arr : outputs) {
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
recurse(arr);
// Insert a dependency for every output to synchronize
// with at the end.
if (!arr.is_evaled()) {
deps.insert({arr.id(), std::shared_future<void>{}});
}
}
}
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
if (arr.is_evaled()) {
if (!retain_graph && arr.has_primitive()) {
arr.detach();
}
continue;
}
auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps;
for (auto& in : arr.inputs()) {
if (auto it = deps.find(in.id()); it != deps.end()) {
arr_deps.push_back(it->second);
}
}
std::shared_ptr<std::promise<void>> p{nullptr};
if (auto it = deps.find(arr.id()); it != deps.end()) {
p = std::make_shared<std::promise<void>>();
it->second = p->get_future().share();
}
if (arr.primitive().device() == Device::gpu) {
if (!metal::is_available()) {
throw std::runtime_error("Metal GPU is not available.");
}
scheduler::enqueue(
stream,
metal::make_task(
arr, std::move(arr_deps), std::move(p), retain_graph));
} else {
auto task = [retain_graph,
arr,
stream,
arr_deps = std::move(arr_deps),
p = std::move(p)]() mutable {
for (auto& d : arr_deps) {
d.wait();
}
scheduler::notify_new_task(stream);
arr.primitive().eval_cpu(arr.inputs(), arr);
if (!retain_graph) {
arr.detach();
}
if (p) {
p->set_value();
}
scheduler::notify_task_completion(stream);
};
scheduler::enqueue(stream, std::move(task));
}
}
for (auto& arr : outputs) {
if (auto it = deps.find(arr.id()); it != deps.end()) {
it->second.wait();
}
}
}
std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotans) {
// Make tracers from given primals
std::vector<array> primals_;
for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream()
: default_stream(default_device());
primals_.push_back(copy(p, s)); // Does not do a deep copy
primals_.back().set_tracer(true);
}
// Pass tracer primals through the function
// Any variables that depend on the primals are marked as tracers
auto outputs = fun(primals_);
// Map outputs to passed cotans while ignoring the outputs
// that have stop_gradient called on them
int cotan_index = 0;
std::vector<std::pair<int, int>> output_cotan_pairs;
for (int i = 0; i < outputs.size(); ++i) {
auto& out = outputs[i];
if (out.has_primitive()) {
if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
continue;
}
}
if (cotan_index >= cotans.size()) {
throw std::invalid_argument(
"[vjp] Number of outputs with gradient does not match number of cotangents.");
}
if (out.shape() != cotans[cotan_index].shape()) {
throw std::invalid_argument(
"[vjp] Output shape does not match shape of cotangent.");
}
output_cotan_pairs.emplace_back(i, cotan_index++);
}
// Topologically sort the compute graph, record outputs
// in the tape if a gradient is needed.
std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> calc_grad;
for (auto& primal : primals_) {
primal.set_tracer(false);
calc_grad.insert(primal.id());
cache.insert(primal.id());
}
std::vector<array> tape;
std::function<void(array&)> recurse;
recurse = [&](auto& a) {
auto id = a.id();
a.set_tracer(false);
// Check if visited and add to cache if not
if (auto inserted = cache.insert(id); !inserted.second) {
return;
}
for (auto& input : a.editable_inputs()) {
recurse(input);
}
// Stop grad
if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
return;
}
// Calculate gradient if any inputs require gradient
for (auto& input : a.inputs()) {
if (calc_grad.find(input.id()) != calc_grad.end()) {
tape.push_back(a);
calc_grad.insert(id);
break;
}
}
};
for (auto& out : outputs) {
recurse(out);
}
// Run the tape backwards, computing vector-jacobian
// products for each primitive
std::unordered_map<std::uintptr_t, array> cotan_map;
for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
}
for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
auto& a = *it;
// Get the arguments whose gradients are needed
std::vector<int> argnums;
for (int i = 0; i < a.inputs().size(); ++i) {
if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
argnums.push_back(i);
}
}
auto cotan_it = cotan_map.find(a.id());
if (cotan_it == cotan_map.end()) {
continue;
}
auto cotan = cotan_map.extract(cotan_it).mapped();
auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums);
auto s = a.primitive().stream();
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();
if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
cotan_it->second = add(cotan_it->second, vjps[i], s);
} else {
cotan_map.insert({in_id, vjps[i]});
}
}
}
std::vector<array> vjps;
for (auto& primal : primals_) {
if (auto cotan_it = cotan_map.find(primal.id());
cotan_it != cotan_map.end()) {
vjps.push_back(cotan_it->second);
} else {
auto s = primal.has_primitive() ? primal.primitive().stream()
: default_stream(default_device());
vjps.push_back(zeros_like(primal, s));
}
}
return {outputs, vjps};
}
std::pair<array, array> vjp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& cotan) {
auto vec_fun = [fun](const std::vector<array>& inputs) {
return std::vector<array>{fun(inputs[0])};
};
auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});
return {outputs[0], vjps[0]};
}
std::pair<std::vector<array>, std::vector<array>> jvp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents) {
if (primals.size() != tangents.size()) {
throw std::invalid_argument(
"[jvp] Number of inputs does not match number of tangents.");
}
for (int i = 0; i < primals.size(); ++i) {
if (primals[i].shape() != tangents[i].shape()) {
throw std::invalid_argument(
"[jvp] Input shape does not match shape of tangent.");
}
}
std::vector<array> primals_;
for (auto& p : primals) {
auto s = p.has_primitive() ? p.primitive().stream()
: default_stream(default_device());
primals_.push_back(copy(p, s)); // Does not do a deep copy
primals_.back().set_tracer(true);
}
auto outputs = fun(primals_);
// Topologically sort the compute graph, record outputs
// in the tape if a gradient is needed.
std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> calc_grad;
for (auto& primal : primals_) {
primal.set_tracer(false);
calc_grad.insert(primal.id());
cache.insert(primal.id());
}
std::vector<array> tape;
std::function<void(array&)> recurse;
recurse = [&](auto& a) {
auto id = a.id();
a.set_tracer(false);
// Check if visited and add to cache if not
if (auto inserted = cache.insert(id); !inserted.second) {
return;
}
for (auto& input : a.editable_inputs()) {
recurse(input);
}
// Stop grad
if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
return;
}
// Calculate gradient if any inputs require gradient
for (auto& input : a.inputs()) {
if (calc_grad.find(input.id()) != calc_grad.end()) {
tape.push_back(a);
calc_grad.insert(id);
break;
}
}
};
for (auto& out : outputs) {
recurse(out);
}
std::unordered_map<std::uintptr_t, array> tan_map;
for (int i = 0; i < primals_.size(); ++i) {
tan_map.insert({primals_[i].id(), tangents[i]});
}
for (auto& a : tape) {
// Get the arguments used in the jvp
std::vector<int> argnums;
std::vector<array> tangents;
for (int i = 0; i < a.inputs().size(); ++i) {
if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
argnums.push_back(i);
tangents.push_back(it->second);
}
}
auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums);
tan_map.insert({a.id(), jvp});
}
std::vector<array> jvps;
for (auto& out : outputs) {
if (auto it = tan_map.find(out.id()); it != tan_map.end()) {
jvps.push_back(it->second);
} else {
auto s = out.has_primitive() ? out.primitive().stream()
: default_stream(default_device());
jvps.push_back(zeros_like(out, s));
}
}
return {outputs, jvps};
}
std::pair<array, array> jvp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& tangent) {
auto vec_fun = [fun](const std::vector<array>& inputs) {
return std::vector<array>{fun(inputs[0])};
};
auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});
return {outputs[0], jvps[0]};
}
ValueAndGradFn value_and_grad(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& argnums) {
if (argnums.empty()) {
throw std::invalid_argument("[grad] Must specify at least one argument.");
}
return [fun, argnums](const std::vector<array>& inputs) {
std::set<int> args;
for (auto& arg : argnums) {
args.insert(arg < 0 ? arg + inputs.size() : arg);
}
if (args.size() != argnums.size()) {
throw std::invalid_argument(
"[grad] Repeat argument number not allowed in grad.");
}
if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
std::ostringstream msg;
msg << "[grad] Invalid argument number for function with "
<< inputs.size() << " inputs.";
throw std::invalid_argument(msg.str());
}
auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
std::vector<array> inputs_(inputs);
auto argit = args.begin();
for (int i = 0; i < ginputs.size(); ++i) {
inputs_[*argit] = ginputs[i];
++argit;
}
auto outputs = fun(inputs_);
for (int i = 1; i < outputs.size(); i++) {
auto& out = outputs[i];
auto s = out.has_primitive() ? out.primitive().stream()
: default_stream(default_device());
outputs[i] = stop_gradient(out, s);
}
return outputs;
};
std::vector<array> ginputs;
for (auto arg : args) {
ginputs.push_back(inputs[arg]);
}
// Set the incoming gradient to an int32 scalar so that it will be promoted
// to the appropriate floating point type: op(int, floatXX) -> floatXX for
// most ops
auto [outputs, grads] = vjp(gfun, ginputs, {array(1)});
return std::make_pair(outputs, grads);
};
}
namespace detail {
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs,
const std::vector<int>& in_axes) {
if (in_axes.size() != inputs.size()) {
throw std::invalid_argument(
"[vmap] The number of in axes must match the number of inputs.");
}
// Run the function on placeholder inputs
// to get the original graph
std::vector<array> s_inputs;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
if (inputs[i].ndim() == 0) {
throw std::invalid_argument(
"[vmap] Cannot vmap an input with zero dimensions.");
}
if (in_axes[i] > inputs[i].ndim()) {
std::ostringstream msg;
msg << "[vmap] Axis " << in_axes[i] << " invalid for input with "
<< inputs[i].ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
std::vector<int> shape = inputs[i].shape();
shape.erase(shape.begin() + in_axes[i]);
array in(shape, inputs[i].dtype(), nullptr, {});
s_inputs.push_back(in);
s_inputs.back().set_tracer(true);
} else {
s_inputs.push_back(inputs[i]);
}
}
return {s_inputs, fun(s_inputs)};
}
std::vector<array> vmap_replace(
const std::vector<array>& inputs,
const std::vector<array>& s_inputs,
const std::vector<array>& s_outputs,
const std::vector<int>& in_axes,
const std::vector<int>& out_axes) {
if (out_axes.size() != s_outputs.size()) {
throw std::invalid_argument(
"[vmap] The number of out axes must match the number of outputs.");
}
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
std::unordered_set<std::uintptr_t> needs_vmap;
for (int i = 0; i < s_inputs.size(); ++i) {
if (in_axes[i] != -1) {
tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}});
needs_vmap.insert(s_inputs[i].id());
}
}
// Topologically sort the graph
std::unordered_set<std::uintptr_t> cache;
for (int i = 0; i < s_inputs.size(); ++i) {
auto in = s_inputs[i];
if (in_axes[i] != -1) {
in.set_tracer(false);
}
cache.insert(in.id());
}
std::vector<array> tape;
std::function<void(const array&)> recurse;
recurse = [&](const array& a) {
// Stop at inputs to the vmap function
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (auto& input : a.inputs()) {
recurse(input);
}
cache.insert(id);
for (auto& input : a.inputs()) {
if (needs_vmap.find(input.id()) != needs_vmap.end()) {
needs_vmap.insert(id);
tape.push_back(a);
tape.back().set_tracer(false);
break;
}
}
};
for (auto& out : s_outputs) {
recurse(out);
}
// Transform each primitive in the graph with
// its vmap implementation
for (auto& a : tape) {
std::vector<array> v_inputs;
std::vector<int> v_axes;
for (auto& in : a.inputs()) {
auto map_it = tmap.find(in.id());
if (map_it != tmap.end()) {
v_inputs.push_back(map_it->second.first);
v_axes.push_back(map_it->second.second);
} else {
v_inputs.push_back(in);
v_axes.push_back(-1);
}
}
auto out_and_axis = a.primitive().vmap(v_inputs, v_axes);
tmap.insert({a.id(), out_and_axis});
}
// Populate the outputs and make sure all the output axes are
// in the right place
std::vector<array> outputs;
for (int i = 0; i < s_outputs.size(); ++i) {
auto map_it = tmap.find(s_outputs[i].id());
if (map_it != tmap.end()) {
auto& [out, vdim] = map_it->second;
if (vdim != out_axes[i]) {
if (out_axes[i] >= out.ndim()) {
std::ostringstream msg;
msg << "[vmap] Axis " << out_axes[i] << " invalid for output with "
<< out.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
std::vector<int> reorder(out.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + vdim);
reorder.insert(reorder.begin() + out_axes[i], vdim);
out = transpose(out, reorder);
}
outputs.push_back(out);
} else {
outputs.push_back(s_outputs[i]);
}
}
return outputs;
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& in_axes /* = {} */,
const std::vector<int>& out_axes /* = {} */) {
auto infer_axes = [](auto axes) {
return !axes.empty() &&
std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });
};
if (infer_axes(in_axes) != infer_axes(out_axes)) {
throw std::invalid_argument(
"[vmap] Input (or output) axes must be "
"specified if output (or input) axes are.");
}
auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](
const std::vector<array>& inputs) mutable {
if (in_axes.size() == 0) {
in_axes.resize(inputs.size(), 0);
}
auto [trace_inputs, trace_outputs] =
detail::vmap_trace(fun, inputs, in_axes);
if (out_axes.size() == 0) {
out_axes.resize(trace_outputs.size(), 0);
}
return detail::vmap_replace(
inputs, trace_inputs, trace_outputs, in_axes, out_axes);
};
return vfun;
}
std::function<array(const array&, const array&)> vmap(
const std::function<array(const array&, const array&)>& fun,
int in_axis_a /* = 0 */,
int in_axis_b /* = 0 */,
int out_axis /* = 0 */) {
auto vfun = vmap(
[in_axis_a, in_axis_b, out_axis, fun](const std::vector<array>& inputs) {
return std::vector<array>{fun(inputs[0], inputs[1])};
},
{in_axis_a, in_axis_b},
{out_axis});
return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; };
}
std::function<array(const array&)> vmap(
const std::function<array(const array&)>& fun,
int in_axis /* = 0 */,
int out_axis /* = 0 */) {
auto vfun = vmap(
[in_axis, out_axis, fun](const std::vector<array>& inputs) {
return std::vector<array>{fun(inputs[0])};
},
{in_axis},
{out_axis});
return [vfun](const array& a) { return vfun({a})[0]; };
}
} // namespace mlx::core

185
mlx/transforms.h Normal file
View File

@ -0,0 +1,185 @@
#pragma once
#include "array.h"
namespace mlx::core {
/** Fuse equivalent arrays to avoid duplicate execution. */
void simplify(const std::vector<array>& outputs);
template <typename... Arrays>
void simplify(Arrays... outputs) {
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
}
void eval(const std::vector<array>& outputs, bool retain_graph = false);
template <typename... Arrays>
void eval(Arrays... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
}
/**
* Computes the output and vector-Jacobian product (VJP) of a function.
*
* Computes the vector-Jacobian product of the vector of cotangents with the
* Jacobian of the function evaluated at the primals. Returns a pair of
* vectors of output arrays and VJP arrays.
**/
std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents);
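// Illustrative usage (a sketch; assumes an elementwise multiply op):
//   auto fun = [](const std::vector<array>& in) {
//     return std::vector<array>{multiply(in[0], in[0])};
//   };
//   auto [outs, grads] = vjp(fun, {array(3.0f)}, {array(1.0f)});
//   // grads[0] evaluates to 2 * 3 = 6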
/**
* Computes the output and vector-Jacobian product (VJP) of a unary function.
*/
std::pair<array, array> vjp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& cotangent);
/**
* Computes the output and Jacobian-vector product (JVP) of a function.
*
* Computes the Jacobian-vector product of the Jacobian of the function
* evaluated at the primals with the vector of tangents. Returns a pair of
* vectors of output arrays and JVP arrays.
**/
std::pair<std::vector<array>, std::vector<array>> jvp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents);
/**
* Computes the output and Jacobian-vector product (JVP) of a unary function.
*/
std::pair<array, array> jvp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& tangent);
// Return type of general value_and_grad: a function which takes an input
// vector of arrays and returns a pair of vectors of arrays: one for the
// values and one for the gradients with respect to the first value.
using ValueAndGradFn =
std::function<std::pair<std::vector<array>, std::vector<array>>(
const std::vector<array>&)>;
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
const std::vector<array>&)>;
/**
* Returns a function which computes the value and gradient of the input
* function with respect to a vector of input arrays.
**/
ValueAndGradFn value_and_grad(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& argnums);
/**
* Returns a function which computes the value and gradient of the input
* function with respect to a single input array.
**/
ValueAndGradFn inline value_and_grad(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
int argnum = 0) {
return value_and_grad(fun, std::vector<int>{argnum});
}
/**
* Returns a function which computes the value and gradient of the unary
* input function.
**/
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
const std::function<array(const array&)>& fun) {
return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
}
SimpleValueAndGradFn inline value_and_grad(
const std::function<array(const std::vector<array>&)>& fun,
const std::vector<int>& argnums) {
return [fun, argnums](auto inputs) {
auto result = value_and_grad(
[fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
argnums)(inputs);
return std::make_pair(result.first[0], result.second);
};
}
SimpleValueAndGradFn inline value_and_grad(
const std::function<array(const std::vector<array>&)>& fun,
int argnum = 0) {
return value_and_grad(fun, std::vector<int>{argnum});
}
/**
* Returns a function which computes the gradient of the input function with
* respect to a vector of input arrays.
*
* The function being differentiated takes a vector of arrays and returns an
* array. The vector of `argnums` specifies which arguments to compute
* the gradient with respect to. At least one argument must be specified.
**/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
const std::function<array(const std::vector<array>&)>& fun,
const std::vector<int>& argnums) {
auto fn = value_and_grad(fun, argnums);
return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
}
/**
* Returns a function which computes the gradient of the input function with
* respect to a single input array.
*
* The function being differentiated takes a vector of arrays and returns an
* array. The optional `argnum` index specifies which argument to compute
* the gradient with respect to and defaults to 0.
**/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
const std::function<array(const std::vector<array>&)>& fun,
int argnum = 0) {
return grad(fun, std::vector<int>{argnum});
}
/**
* Returns a function which computes the gradient of the unary input function.
**/
std::function<array(const array&)> inline grad(
const std::function<array(const array&)>& fun) {
auto fn = value_and_grad(fun);
return [fn](const array& input) { return fn(input).second; };
}
/**
* Automatically vectorize a unary function over the requested axes.
*/
std::function<array(const array&)> vmap(
const std::function<array(const array&)>& fun,
int in_axis = 0,
int out_axis = 0);
/**
* Automatically vectorize a binary function over the requested axes.
*/
std::function<array(const array&, const array&)> vmap(
const std::function<array(const array&, const array&)>& fun,
int in_axis_a = 0,
int in_axis_b = 0,
int out_axis = 0);
/**
* Automatically vectorize a function over the requested axes.
*
* The input function to `vmap` takes as an argument a vector of arrays and
* returns a vector of arrays. Optionally specify the axes to vectorize over
* with `in_axes` and `out_axes`, otherwise a default of 0 is used.
* Returns a vectorized function with the same signature as the input
* function.
*/
std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& in_axes = {},
const std::vector<int>& out_axes = {});
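// Illustrative usage (a sketch; assumes an elementwise add op):
//   auto vadd = vmap([](const array& a, const array& b) { return add(a, b); });
//   // vadd maps add over axis 0 of both inputs and stacks the per-slice
//   // results along axis 0 of the output.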
} // namespace mlx::core

185
mlx/types/bf16.h Normal file
View File

@ -0,0 +1,185 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
#define __MLX_BFLOAT_NAN__ 0x7FC0
namespace mlx::core {
namespace {
union float_bits_bf16 {
float f;
uint32_t u;
};
} // namespace
struct _MLX_BFloat16 {
uint16_t bits_;
// Default constructor
_MLX_BFloat16() = default;
// Default copy constructor
_MLX_BFloat16(_MLX_BFloat16 const&) = default;
// Appease std::vector<bool> for being special
_MLX_BFloat16& operator=(std::vector<bool>::reference x) {
bits_ = x;
return *this;
}
_MLX_BFloat16& operator=(const float& x) {
return (*this = _MLX_BFloat16(x));
}
// From float32
_MLX_BFloat16(const float& x) {
if (std::isnan(x)) {
bits_ = __MLX_BFLOAT_NAN__;
} else {
// Union
float_bits_bf16 in;
// Take bits
in.f = x;
// Round to nearest even
in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);
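// e.g. 0x3F808000 (just above 1.0f) has bit 16 clear, so adding 0x7FFF
// leaves the upper half at 0x3F80, while 0x3F818000 has bit 16 set and
// rounds up to 0x3F82: exact ties round to the nearest even bfloat16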
// Take upper 16 bits
bits_ = in.u >> 16;
}
}
// To float32
operator float() const {
// Union
float_bits_bf16 out;
// Upper 16 bits are the data and lower 16 bits are 0s
out.u = ((uint32_t)bits_) << 16;
return out.f;
}
};
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
inline otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
// Operators
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base( \
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, double, double, double); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
#undef bfloat_binop
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base( \
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, double, double); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
#undef bfloat_compop
// Negative
inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
return -static_cast<float>(lhs);
}
// Inplace ops
#define bfloat_inplace_op(__op__, __operator__) \
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
lhs = lhs __op__ rhs; \
return lhs; \
} \
inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \
lhs = lhs __op__ rhs; \
return lhs; \
}
bfloat_inplace_op(+, operator+=);
bfloat_inplace_op(-, operator-=);
bfloat_inplace_op(*, operator*=);
bfloat_inplace_op(/, operator/=);
#undef bfloat_inplace_op
// Bitwise ops
#define bfloat_bitop(__op__, __operator__) \
inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
_MLX_BFloat16 out; \
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
return out; \
} \
inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \
_MLX_BFloat16 out; \
out.bits_ = lhs.bits_ __op__ rhs; \
return out; \
} \
inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \
_MLX_BFloat16 out; \
out.bits_ = lhs __op__ rhs.bits_; \
return out; \
}
bfloat_bitop(|, operator|);
bfloat_bitop(&, operator&);
bfloat_bitop(^, operator^);
#undef bfloat_bitop
#define bfloat_inplace_bitop(__op__, __operator__) \
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
return lhs; \
} \
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \
lhs.bits_ = lhs.bits_ __op__ rhs; \
return lhs; \
}
bfloat_inplace_bitop(|, operator|=);
bfloat_inplace_bitop(&, operator&=);
bfloat_inplace_bitop(^, operator^=);
#undef bfloat_inplace_bitop
} // namespace mlx::core

View File

@ -0,0 +1,3 @@
from mlx.nn.layers import *
from mlx.nn import losses
from mlx.nn.utils import value_and_grad

View File

@ -0,0 +1,23 @@
from mlx.nn.layers.base import Module
from mlx.nn.layers.activations import (
GELU,
ReLU,
SiLU,
gelu,
gelu_approx,
gelu_fast_approx,
relu,
silu,
)
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.transformer import (
MultiHeadAttention,
TransformerEncoder,
TransformerEncoderLayer,
)

View File

@ -0,0 +1,129 @@
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
def _make_activation_module(f):
def decorator(klass):
klass.__doc__ = f.__doc__
klass.__call__ = lambda self, x: f(x)
return klass
return decorator
def relu(x):
"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
"""
return mx.maximum(x, 0)
def silu(x):
r"""Applies the Sigmoid Linear Unit.
Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
the logistic sigmoid.
"""
return x * mx.sigmoid(x)
def gelu(x):
"""Applies the Gaussian Error Linear Units function.
.. math::
\\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
approximations.
"""
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
def gelu_approx(x):
r"""An approximation to Gaussian Error Linear Unit.
See :func:`gelu` for the exact computation.
This function approximates ``gelu`` with a maximum absolute error :math:`<
0.0003` in the range :math:`[-6, 6]` using the following
.. math::
x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
"""
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
def gelu_fast_approx(x):
r"""A fast approximation to Gaussian Error Linear Unit.
See :func:`gelu` for the exact computation.
This function approximates ``gelu`` with a maximum absolute error :math:`<
0.015` in the range :math:`[-6, 6]` using the following
.. math::
x = x \sigma\left(1.773 x\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
"""
return x * mx.sigmoid(1.773 * x)
@_make_activation_module(relu)
class ReLU(Module):
pass
@_make_activation_module(silu)
class SiLU(Module):
pass
class GELU(Module):
r"""Applies the Gaussian Error Linear Units.
.. math::
\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
However, if ``approx`` is set to 'precise' or 'fast' it applies
.. math::
\textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
\textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
respectively.
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
functional equivalents and information regarding error bounds.
Args:
approx ('none' | 'precise' | 'fast'): Which gelu approximation to use, if any.
"""
def __init__(self, approx="none"):
super().__init__()
if approx == "none":
self._act = gelu
elif approx == "precise":
self._act = gelu_approx
elif approx == "fast":
self._act = gelu_fast_approx
else:
raise ValueError(
f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
)
def __call__(self, x):
return self._act(x)
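# Illustrative usage (a sketch):
#   act = GELU(approx="fast")
#   y = act(mx.array([-1.0, 0.0, 1.0]))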

6
python/mlx/nn/losses.py Normal file
View File

@ -0,0 +1,6 @@
import mlx.core as mx
def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
return mx.logsumexp(logits, axis=axis) - score
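# Illustrative usage (a sketch):
#   logits = mx.array([[2.0, 0.5], [0.1, 1.5]])
#   targets = mx.array([0, 1])
#   losses = cross_entropy(logits, targets)  # per-example -log p(target)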

136
python/mlx/utils.py Normal file
View File

@ -0,0 +1,136 @@
def tree_map(fn, tree, *rest):
"""Applies ``fn`` to the leaves of the python tree ``tree`` and
returns a new collection with the results.
If ``rest`` is provided, every item is assumed to be a superset of ``tree``
and the corresponding leaves are provided as extra positional arguments to
``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
than to :func:`map`.
.. code-block:: python
import mlx.nn as nn
from mlx.utils import tree_map
model = nn.Linear(10, 10)
print(model.parameters().keys())
# dict_keys(['weight', 'bias'])
# square the parameters
model.update(tree_map(lambda x: x*x, model.parameters()))
Args:
fn (Callable): The function that processes the leaves of the tree
tree (Any): The main python tree that will be iterated upon
rest (Tuple[Any]): Extra trees to be iterated together with tree
Returns:
A python tree with the new values returned by ``fn``.
"""
if isinstance(tree, list):
return [
tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
]
elif isinstance(tree, tuple):
return tuple(
tree_map(fn, child, *(r[i] for r in rest)) for i, child in enumerate(tree)
)
elif isinstance(tree, dict):
return {
k: tree_map(fn, child, *(r[k] for r in rest)) for k, child in tree.items()
}
else:
return fn(tree, *rest)
def tree_flatten(tree, prefix="", is_leaf=None):
"""Flattens a python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and
complexity.
.. code-block:: python
from mlx.utils import tree_flatten
print(tree_flatten([[[0]]]))
# [("0.0.0", 0)]
print(tree_flatten([[[0]]], ".hello"))
# [("hello.0.0.0", 0)]
.. note::
Dictionaries should have keys that are valid python identifiers.
Args:
tree (Any): The python tree to be flattened.
prefix (str): A prefix to use for the keys. The first character is
always discarded.
is_leaf (Callable): An optional callable that returns True if the
passed object is considered a leaf or False otherwise.
Returns:
List[Tuple[str, Any]]: The flat representation of the python tree.
"""
flat_tree = []
if is_leaf is None or not is_leaf(tree):
if isinstance(tree, (list, tuple)):
for i, t in enumerate(tree):
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
return flat_tree
if isinstance(tree, dict):
for k, t in tree.items():
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
return flat_tree
return [(prefix[1:], tree)]
def tree_unflatten(tree):
"""Recreate a python tree from its flat representation.
.. code-block:: python
from mlx.utils import tree_unflatten
d = tree_unflatten([("hello.world", 42)])
print(d)
# {"hello": {"world": 42}}
Args:
tree (List[Tuple[str, Any]]): The flat representation of a python tree.
For instance as returned by :meth:`tree_flatten`.
Returns:
A python tree.
"""
if len(tree) == 1 and tree[0][0] == "":
return tree[0][1]
try:
int(tree[0][0].split(".", maxsplit=1)[0])
is_list = True
except ValueError:
is_list = False
# collect children
children = {}
for key, value in tree:
current_idx, *next_idx = key.split(".", maxsplit=1)
next_idx = "" if not next_idx else next_idx[0]
if current_idx not in children:
children[current_idx] = []
children[current_idx].append((next_idx, value))
# recursively map them to the original container
if is_list:
keys = sorted((int(idx), idx) for idx in children.keys())
l = []
for i, k in keys:
while i > len(l):
l.append({})
l.append(tree_unflatten(children[k]))
return l
else:
return {k: tree_unflatten(v) for k, v in children.items()}
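
A short round-trip sketch of the three helpers above: flatten a nested container to dotted keys, rebuild it, and map a function over its leaves.

    from mlx.utils import tree_flatten, tree_map, tree_unflatten

    params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}
    flat = tree_flatten(params)
    # [('layers.0.w', 1.0), ('layers.0.b', 0.0), ('layers.1.w', 2.0), ('layers.1.b', 0.5)]
    assert tree_unflatten(flat) == params       # flatten/unflatten round-trips
    doubled = tree_map(lambda x: 2 * x, params) # same structure, every leaf doubled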

1071
python/src/array.cpp Normal file

File diff suppressed because it is too large

42
python/src/device.cpp Normal file
View File

@ -0,0 +1,42 @@
#include <sstream>
#include <pybind11/pybind11.h>
#include "mlx/device.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_device(py::module_& m) {
py::enum_<Device::DeviceType>(m, "DeviceType")
.value("cpu", Device::DeviceType::cpu)
.value("gpu", Device::DeviceType::gpu)
.export_values()
.def(
"__eq__",
[](const Device::DeviceType& d1, const Device& d2) {
return d1 == d2;
},
py::prepend());
py::class_<Device>(m, "Device")
.def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
.def_readonly("type", &Device::type)
.def(
"__repr__",
[](const Device& d) {
std::ostringstream os;
os << d;
return os.str();
})
.def("__eq__", [](const Device& d1, const Device& d2) {
return d1 == d2;
});
py::implicitly_convertible<Device::DeviceType, Device>();
m.def("default_device", &default_device);
m.def("set_default_device", &set_default_device, "device"_a);
}
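
Roughly how these bindings look from Python, as a sketch assuming they are re-exported from mlx.core: the implicitly_convertible call lets a bare DeviceType stand in for a Device, and the prepended __eq__ handles the mixed-type comparison.

    import mlx.core as mx

    print(mx.default_device())             # e.g. Device(gpu, 0)
    mx.set_default_device(mx.cpu)          # DeviceType converts implicitly to Device
    d = mx.Device(mx.gpu, 0)
    print(d == mx.gpu, mx.gpu == d)        # both directions should compare equal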

12
python/src/metal.cpp Normal file
View File

@ -0,0 +1,12 @@
#include <pybind11/pybind11.h>
#include "mlx/backend/metal/metal.h"
namespace py = pybind11;
using namespace mlx::core;
void init_metal(py::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def("is_available", &metal::is_available);
}
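
The single function exposed here is enough to guard GPU usage at runtime; a minimal sketch:

    import mlx.core as mx

    if mx.metal.is_available():
        mx.set_default_device(mx.gpu)      # run on the Metal backend when present
    else:
        mx.set_default_device(mx.cpu)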

32
python/src/stream.cpp Normal file
View File

@ -0,0 +1,32 @@
#include <sstream>
#include <pybind11/pybind11.h>
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_stream(py::module_& m) {
py::class_<Stream>(m, "Stream")
.def(py::init<int, Device>(), "index"_a, "device"_a)
.def_readonly("device", &Stream::device)
.def(
"__repr__",
[](const Stream& s) {
std::ostringstream os;
os << s;
return os.str();
})
.def("__eq__", [](const Stream& s1, const Stream& s2) {
return s1 == s2;
});
py::implicitly_convertible<Device::DeviceType, Device>();
m.def("default_stream", &default_stream, "device"_a);
m.def("set_default_stream", &set_default_stream, "stream"_a);
m.def("new_stream", &new_stream, "device"_a);
}
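
A small sketch of the stream API from Python, assuming a Metal-capable machine so mx.gpu is valid: streams are created per device, and one of them is that device's default.

    import mlx.core as mx

    s = mx.new_stream(mx.gpu)              # a fresh stream on the GPU device
    print(mx.default_stream(mx.gpu))       # the current default stream for that device
    mx.set_default_stream(s)               # route subsequent work to the new stream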

188
python/tests/test_bf16.py Normal file
View File

@ -0,0 +1,188 @@
import unittest
from itertools import permutations
import math
import mlx.core as mx
import numpy as np
import mlx_tests
try:
import torch
has_torch = True
except ImportError as e:
has_torch = False
class TestBF16(mlx_tests.MLXTestCase):
def __test_ops(
self,
ref_op, # Function that outputs array_like
mlx_op, # Function that outputs array_like
np_args, # Numpy arguments
ref_transform=lambda x: x,
mlx_transform=lambda x: mx.array(x),
atol=1e-5,
):
ref_args = map(ref_transform, np_args)
mlx_args = map(mlx_transform, np_args)
r_ref = ref_op(*ref_args)
r_mlx = mlx_op(*mlx_args)
self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))
def __default_test(
self,
op,
np_args,
simple_transform=lambda x: x,
atol_np=1e-3,
atol_torch=1e-5,
np_kwargs=dict(),
mlx_kwargs=dict(),
torch_kwargs=dict(),
torch_op=None,
):
with self.subTest(reference="numpy"):
def np_transform(x):
x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
x_mx_fp32 = x_mx_bf16.astype(mx.float32)
return np.asarray(x_mx_fp32)
def mlx_fn(*args):
out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
return np.asarray(out_bf16.astype(mx.float32))
def np_fn(*args):
out_fp32 = getattr(np, op)(*args, **np_kwargs)
return np_transform(out_fp32)
ref_op = np_fn
mlx_op = mlx_fn
ref_transform = lambda x: simple_transform(np_transform(x))
mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))
self.__test_ops(
ref_op,
mlx_op,
np_args,
ref_transform=ref_transform,
mlx_transform=mlx_transform,
atol=atol_np,
)
if has_torch:
with self.subTest(reference="torch"):
torch_op = op if torch_op is None else torch_op
def torch_fn(*args):
out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
return out_bf16.to(torch.float32).numpy()
ref_op = torch_fn
ref_transform = lambda x: simple_transform(
torch.from_numpy(x).to(torch.bfloat16)
)
self.__test_ops(
ref_op,
mlx_op,
np_args,
ref_transform=ref_transform,
mlx_transform=mlx_transform,
atol=atol_torch,
)
def test_unary_ops(self):
x = np.random.rand(18, 28, 38)
for op in ["abs", "exp", "log", "square", "sqrt"]:
with self.subTest(op=op):
np_args = (x.astype(np.float32),)
self.__default_test(op, np_args)
def test_binary_ops(self):
x = np.random.rand(18, 28, 38)
y = np.random.rand(18, 28, 38)
for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
with self.subTest(op=op):
np_args = (
x.astype(np.float32),
y.astype(np.float32),
)
self.__default_test(op, np_args, simple_transform=lambda x: x)
self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])
def test_reduction_ops(self):
x = np.random.rand(18, 28, 38).astype(np.float32)
for op in ("min", "max"):
with self.subTest(op=op):
for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
with self.subTest(axes=axes):
np_args = (x.astype(np.float32),)
self.__default_test(
op,
np_args,
np_kwargs={"axis": axes},
mlx_kwargs={"axis": axes},
torch_kwargs={"dim": axes},
torch_op="a" + op,
)
def test_arg_reduction_ops(self):
data = np.random.rand(10, 12, 13).astype(np.float32)
x = mx.array(data).astype(mx.bfloat16)
data = np.asarray(x.astype(mx.float32))
for op in ["argmin", "argmax"]:
for axis in range(3):
for kd in [True, False]:
a = getattr(mx, op)(x, axis, kd)
b = getattr(np, op)(data, axis, keepdims=kd)
a = a.astype(mx.float32)
self.assertEqual(a.tolist(), b.tolist())
for op in ["argmin", "argmax"]:
a = getattr(mx, op)(x, keepdims=True)
b = getattr(np, op)(data, keepdims=True)
a = a.astype(mx.float32)
self.assertEqual(a.tolist(), b.tolist())
a = getattr(mx, op)(x)
b = getattr(np, op)(data)
a = a.astype(mx.float32)
self.assertEqual(a.item(), b)
def test_blas_ops(self):
if mx.default_device() != mx.gpu:
return
def test_blas(shape_x, shape_y):
np.random.seed(42)
with self.subTest(shape_x=shape_x, shape_y=shape_y):
x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)
np_args = (
x.astype(np.float32),
y.astype(np.float32),
)
op = "matmul"
self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)
for shape_x, shape_y in [
[(32, 32), (32, 32)],
[(23, 57), (57, 1)],
[(1, 3), (3, 128)],
[(8, 128, 768), (768, 16)],
]:
test_blas(shape_x, shape_y)
if __name__ == "__main__":
unittest.main()
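
The pattern these tests rely on, in isolation: round the inputs through bfloat16 first so the reference sees the same quantized values, then compare in float32 with a loose tolerance. A standalone sketch:

    import mlx.core as mx
    import numpy as np

    x = np.random.rand(4, 4).astype(np.float32)
    x_bf16 = mx.array(x).astype(mx.bfloat16)
    x_ref = np.asarray(x_bf16.astype(mx.float32))       # bf16-rounded copy for the reference
    out = np.asarray(mx.exp(x_bf16).astype(mx.float32))
    assert np.allclose(out, np.exp(x_ref), atol=1e-3)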

224
tests/creations_tests.cpp Normal file
View File

@ -0,0 +1,224 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test arange") {
// Check type is inferred correctly
{
auto x = arange(10);
CHECK_EQ(x.dtype(), int32);
x = arange(10.0);
CHECK_EQ(x.dtype(), float32);
x = arange(10, float32);
CHECK_EQ(x.dtype(), float32);
x = arange(10.0, int32);
CHECK_EQ(x.dtype(), int32);
x = arange(0, 10);
CHECK_EQ(x.dtype(), int32);
x = arange(0.0, 10.0, int32);
CHECK_EQ(x.dtype(), int32);
x = arange(0.0, 10.0);
CHECK_EQ(x.dtype(), float32);
x = arange(0, 10, float32);
CHECK_EQ(x.dtype(), float32);
x = arange(0, 10, 0.1, float32);
CHECK_EQ(x.dtype(), float32);
x = arange(0.0, 10.0, 0.5, int32);
CHECK_EQ(x.dtype(), int32);
x = arange(10.0, uint32);
CHECK_EQ(x.dtype(), uint32);
x = arange(0.0, 10.0, uint32);
CHECK_EQ(x.dtype(), uint32);
x = arange(0.0, 10.0, 0.5, uint32);
CHECK_EQ(x.dtype(), uint32);
// arange unsupported for bool_
CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument);
}
// Check correct sizes
{
auto x = arange(10);
CHECK_EQ(x.size(), 10);
x = arange(0.0, 10.0, 0.5);
CHECK_EQ(x.size(), 20);
x = arange(0.0, 10.0, 0.45);
CHECK_EQ(x.size(), 23);
x = arange(0, 10, 10);
CHECK_EQ(x.size(), 1);
x = arange(0, 10, 9);
CHECK_EQ(x.size(), 2);
x = arange(0, 10, 100);
CHECK_EQ(x.size(), 1);
x = arange(0, -10, 1);
CHECK_EQ(x.size(), 0);
x = arange(0, -10, -1);
CHECK_EQ(x.size(), 10);
x = arange(0, -10, -10);
CHECK_EQ(x.size(), 1);
}
// Check values
{
auto x = arange(0, 3);
CHECK(array_equal(x, array({0, 1, 2})).item<bool>());
x = arange(0, 3, 2);
CHECK(array_equal(x, array({0, 2})).item<bool>());
x = arange(0, 3, 3);
CHECK(array_equal(x, array({0})).item<bool>());
x = arange(0, -3, 1);
CHECK(array_equal(x, array({})).item<bool>());
x = arange(0, 3, -1);
CHECK(array_equal(x, array({})).item<bool>());
x = arange(0, -3, -1);
CHECK(array_equal(x, array({0, -1, -2})).item<bool>());
x = arange(0.0, 5.0, 0.5, int32);
CHECK(array_equal(x, zeros({10})).item<bool>());
x = arange(0.0, 5.0, 1.5, int32);
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
}
}
TEST_CASE("test astype") {
// Check type conversions
{
auto x = array(1);
auto y = astype(x, float32);
CHECK_EQ(y.dtype(), float32);
CHECK_EQ(y.item<float>(), 1.0f);
y = astype(x, int32);
CHECK_EQ(y.dtype(), int32);
CHECK_EQ(y.item<int>(), 1);
x = array(-3.0f);
y = astype(x, int32);
CHECK_EQ(y.dtype(), int32);
CHECK_EQ(y.item<int>(), -3);
y = astype(x, uint32);
CHECK_EQ(y.dtype(), uint32);
// Use std::copy since the result is platform dependent
uint32_t v;
std::copy(x.data<float>(), x.data<float>() + 1, &v);
CHECK_EQ(y.item<uint32_t>(), v);
}
}
TEST_CASE("test full") {
// Check full works for different types
{
auto x = full({}, 0);
CHECK_EQ(x.dtype(), int32);
CHECK_EQ(x.item<int>(), 0);
x = full({}, 0.0);
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.item<float>(), 0);
x = full({}, false);
CHECK_EQ(x.item<bool>(), false);
x = full({}, 0, int32);
CHECK_EQ(x.item<int>(), 0);
x = full({}, 0, float32);
CHECK_EQ(x.item<float>(), 0);
x = full({1, 2}, 2, float32);
CHECK(array_equal(x, array({2.0, 2.0}, {1, 2})).item<bool>());
x = full({2, 1}, 2, float32);
CHECK(array_equal(x, array({2.0, 2.0}, {2, 1})).item<bool>());
x = full({2}, false);
CHECK_EQ(x.dtype(), bool_);
CHECK(array_equal(x, array({false, false})).item<bool>());
x = full({2}, 1.0, bool_);
CHECK_EQ(x.dtype(), bool_);
CHECK(array_equal(x, array({true, true})).item<bool>());
x = full({2}, 1.0, uint32);
CHECK_EQ(x.dtype(), uint32);
CHECK(array_equal(x, array({1, 1})).item<bool>());
CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument);
}
// Check broadcasting works
{
auto x = full({2, 2}, array({3, 4}, {2, 1}));
CHECK(array_equal(x, array({3, 3, 4, 4}, {2, 2})).item<bool>());
x = full({2, 2}, array({3, 4}, {1, 2}));
CHECK(array_equal(x, array({3, 4, 3, 4}, {2, 2})).item<bool>());
}
// Check zeros and ones
{
auto x = zeros({2, 2}, float32);
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
CHECK_EQ(x.ndim(), 2);
CHECK_EQ(x.dtype(), float32);
auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
CHECK(array_equal(x, y).item<bool>());
x = ones({2, 2}, float32);
CHECK_EQ(x.shape(), std::vector<int>{2, 2});
CHECK_EQ(x.ndim(), 2);
CHECK_EQ(x.dtype(), float32);
y = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
CHECK(array_equal(x, y).item<bool>());
x = zeros({2, 2}, int32);
y = zeros_like(x);
CHECK_EQ(y.dtype(), int32);
CHECK(array_equal(x, y).item<bool>());
x = ones({2, 2}, int32);
y = ones_like(x);
CHECK_EQ(y.dtype(), int32);
CHECK(array_equal(x, y).item<bool>());
}
// Works for empty shape and empty array
{
array x = ones({}, int32);
CHECK_EQ(x.shape(), std::vector<int>{});
CHECK_EQ(x.item<int>(), 1);
x = full({0}, array({}));
CHECK_EQ(x.shape(), std::vector<int>{0});
CHECK_EQ(x.size(), 0);
CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);
}
}

331
tests/fft_tests.cpp Normal file
View File

@ -0,0 +1,331 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test fft basics") {
auto device = default_device();
set_default_device(Device::cpu);
array x(1.0);
CHECK_THROWS(fft::fft(x));
CHECK_THROWS(fft::ifft(x));
x = array({1.0});
auto y = fft::fft(x);
CHECK_EQ(y.dtype(), complex64);
CHECK_EQ(y.size(), x.size());
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});
y = fft::ifft(x);
CHECK_EQ(y.dtype(), complex64);
CHECK_EQ(y.size(), x.size());
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});
x = array({complex64_t{1.0f, 1.0f}}, complex64);
y = fft::fft(x);
CHECK_EQ(y.size(), x.size());
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});
y = fft::ifft(x);
CHECK_EQ(y.dtype(), complex64);
CHECK_EQ(y.size(), x.size());
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 1.0f});
{
x = array({0.0f, 1.0f, 2.0f, 3.0f});
y = fft::fft(x);
std::initializer_list<complex64_t> expected = {
{6.0, 0.0},
{-2.0, 2.0},
{-2.0, 0.0},
{-2.0, -2.0},
};
CHECK_EQ(y.size(), x.size());
CHECK(array_equal(y, array(expected)).item<bool>());
y = fft::ifft(x);
std::initializer_list<complex64_t> expected_inv = {
{1.5, 0.0},
{-0.5, -0.5},
{-0.5, 0.0},
{-0.5, 0.5},
};
CHECK(array_equal(y, array(expected_inv)).item<bool>());
}
{
std::initializer_list<complex64_t> vals = {
{1.0f, 1.0f}, {2.0f, 1.0f}, {1.0f, 2.0f}, {2.0f, 2.0f}};
x = array(vals);
y = fft::fft(x);
std::initializer_list<complex64_t> expected = {
{6.0, 6.0},
{-1.0, -1.0},
{-2.0, 0.0},
{1.0, -1.0},
};
CHECK_EQ(y.size(), x.size());
CHECK(array_equal(y, array(expected)).item<bool>());
CHECK(array_equal(fft::ifft(y), x).item<bool>());
}
// Specify axes
{
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
std::initializer_list<complex64_t> expected_0 = {
{2.0, 0.0},
{4.0, 0.0},
{-2.0, 0.0},
{-2.0, 0.0},
};
y = fft::fft(x, 0);
CHECK(array_equal(y, array(expected_0, {2, 2})).item<bool>());
CHECK(array_equal(fft::ifft(y, 0), x).item<bool>());
std::initializer_list<complex64_t> expected_1 = {
{1.0, 0.0},
{-1.0, 0.0},
{5.0, 0.0},
{-1.0, 0.0},
};
y = fft::fft(x, 1);
CHECK(array_equal(y, array(expected_1, {2, 2})).item<bool>());
CHECK(array_equal(fft::ifft(y, 1), x).item<bool>());
}
set_default_device(device);
}
TEST_CASE("test real ffts") {
auto device = default_device();
set_default_device(Device::cpu);
auto x = array({1.0});
auto y = fft::rfft(x);
CHECK_EQ(y.dtype(), complex64);
CHECK_EQ(y.size(), x.size());
CHECK_EQ(y.item<complex64_t>(), complex64_t{1.0f, 0.0f});
{
x = array({0.0f, 1.0f, 2.0f, 3.0f});
y = fft::rfft(x);
std::initializer_list<complex64_t> expected = {
{6.0, 0.0}, {-2.0, 2.0}, {-2.0, -0.0}};
CHECK_EQ(y.size(), x.size() / 2 + 1);
CHECK(array_equal(y, array(expected)).item<bool>());
}
x = array(complex64_t{1, 1});
CHECK_THROWS(fft::irfft(x));
x = array({complex64_t{0, 1}, complex64_t{1, 0}});
y = fft::irfft(x);
CHECK_EQ(y.size(), 2);
CHECK_EQ(y.dtype(), float32);
CHECK(array_equal(y, array({0.5f, -0.5f})).item<bool>());
set_default_device(device);
}
TEST_CASE("test fftn") {
auto device = default_device();
set_default_device(Device::cpu);
auto x = zeros({5, 5, 5});
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 3}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {}, {0, -4}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {}, {0, 0}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {5, 5, 5}, {0}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {0}, {}, {}), std::invalid_argument);
CHECK_THROWS_AS(fft::fftn(x, {1, -1}, {}, {}), std::invalid_argument);
// Test 2D FFT
{
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
std::initializer_list<complex64_t> expected = {
{6.0, 0.0},
{-2.0, 0.0},
{-4.0, 0.0},
{0.0, 0.0},
};
auto y = fft::fft2(x);
CHECK(array_equal(y, array(expected, {2, 2})).item<bool>());
CHECK(array_equal(fft::ifft2(y), x).item<bool>());
}
// Test 3D FFT
{
x = reshape(arange(8, float32), {2, 2, 2});
std::initializer_list<complex64_t> expected = {
{28.0, 0.0},
{-4.0, 0.0},
{-8.0, 0.0},
{0.0, 0.0},
{-16.0, 0.0},
{0.0, 0.0},
{0.0, 0.0},
{0.0, 0.0},
};
auto y = fft::fftn(x);
CHECK(array_equal(y, array(expected, {2, 2, 2})).item<bool>());
CHECK(array_equal(fft::ifftn(y), x).item<bool>());
x = reshape(arange(20, float32), {5, 4});
y = fft::rfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
y = fft::rfftn(x, {1, 0});
CHECK_EQ(y.shape(), std::vector<int>{3, 4});
x = reshape(arange(20, float32), {5, 4});
y = fft::irfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 6});
y = fft::irfftn(x, {1, 0});
CHECK_EQ(y.shape(), std::vector<int>{8, 4});
}
// Check the types of real ffts
{
x = zeros({5, 5}, float32);
auto y = fft::rfft2(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.dtype(), complex64);
y = fft::rfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
CHECK_EQ(y.dtype(), complex64);
x = zeros({5, 5}, complex64);
y = fft::irfft2(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.dtype(), float32);
y = fft::irfftn(x);
CHECK_EQ(y.shape(), std::vector<int>{5, 8});
CHECK_EQ(y.dtype(), float32);
}
set_default_device(device);
}
TEST_CASE("test fft with provided shape") {
auto x = ones({5, 5});
auto y = fft::fft(x, 7, 0);
CHECK_EQ(y.shape(), std::vector<int>{7, 5});
y = fft::fft(x, 3, 0);
CHECK_EQ(y.shape(), std::vector<int>{3, 5});
y = fft::fft(x, 7, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 7});
y = fft::fft(x, 3, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 3});
y = fft::rfft(x, 7, 0);
CHECK_EQ(y.shape(), std::vector<int>{4, 5});
y = fft::rfft(x, 3, 0);
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
y = fft::rfft(x, 3, 1);
CHECK_EQ(y.shape(), std::vector<int>{5, 2});
}
TEST_CASE("test fft vmap") {
auto device = default_device();
set_default_device(Device::cpu);
auto fft_fn = [](array x) { return fft::fft(x); };
auto x = reshape(arange(8), {2, 4});
auto y = vmap(fft_fn)(x);
CHECK(array_equal(y, fft::fft(x)).item<bool>());
y = vmap(fft_fn, 1, 1)(x);
CHECK(array_equal(y, fft::fft(x, 0)).item<bool>());
auto rfft_fn = [](array x) { return fft::rfft(x); };
y = vmap(rfft_fn)(x);
CHECK(array_equal(y, fft::rfft(x)).item<bool>());
y = vmap(rfft_fn, 1, 1)(x);
CHECK(array_equal(y, fft::rfft(x, 0)).item<bool>());
set_default_device(device);
}
TEST_CASE("test fft grads") {
auto device = default_device();
set_default_device(Device::cpu);
// Regular
auto fft_fn = [](array x) { return fft::fft(x); };
auto cotangent = astype(arange(10), complex64);
auto vjp_out = vjp(fft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::fft(cotangent), vjp_out).item<bool>());
auto tangent = astype(arange(10), complex64);
auto jvp_out = jvp(fft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::fft(tangent), jvp_out).item<bool>());
// Inverse
auto ifft_fn = [](array x) { return fft::ifft(x); };
vjp_out = vjp(ifft_fn, zeros_like(cotangent), cotangent).second;
CHECK(array_equal(fft::ifft(cotangent), vjp_out).item<bool>());
jvp_out = jvp(ifft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::ifft(tangent), jvp_out).item<bool>());
// Real
auto rfft_fn = [](array x) { return fft::rfft(x); };
cotangent = astype(arange(6), complex64);
vjp_out = vjp(rfft_fn, zeros({10}), cotangent).second;
auto expected = astype(fft::fft(cotangent, 10, 0), float32);
CHECK(array_equal(expected, vjp_out).item<bool>());
tangent = astype(arange(10), float32);
jvp_out = jvp(rfft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::rfft(tangent), jvp_out).item<bool>());
// Inverse real
auto irfft_fn = [](array x) { return fft::irfft(x); };
cotangent = astype(arange(10), float32);
vjp_out = vjp(irfft_fn, astype(zeros({6}), complex64), cotangent).second;
expected = fft::fft(cotangent, 10, 0);
auto o_splits = split(vjp_out, {1, 5});
auto e_splits = split(expected, {1, 5, 6});
CHECK_EQ(e_splits[0].item<complex64_t>(), o_splits[0].item<complex64_t>());
CHECK(array_equal(2 * e_splits[1], o_splits[1]).item<bool>());
CHECK_EQ(e_splits[2].item<complex64_t>(), o_splits[2].item<complex64_t>());
tangent = astype(arange(10), complex64);
jvp_out = jvp(irfft_fn, zeros_like(tangent), tangent).second;
CHECK(array_equal(fft::irfft(tangent), jvp_out).item<bool>());
// Check ND vjps run properly
vjp_out = vjp([](array x) { return fft::fftn(x); },
astype(zeros({5, 5}), complex64),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
vjp_out = vjp([](array x) { return fft::ifftn(x); },
astype(zeros({5, 5}), complex64),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
vjp_out = vjp([](array x) { return fft::rfftn(x); },
zeros({5, 9}),
astype(zeros({5, 5}), complex64))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 9});
vjp_out = vjp([](array x) { return fft::irfftn(x); },
astype(zeros({5, 5}), complex64),
zeros({5, 8}))
.second;
CHECK_EQ(vjp_out.shape(), std::vector<int>{5, 5});
set_default_device(device);
}
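
The shape rules checked above can be cross-checked against NumPy: a real FFT shrinks the transformed axis to n // 2 + 1, and the inverse real FFT expands it to 2 * (n - 1) unless an explicit size is given.

    import numpy as np

    x = np.arange(20, dtype=np.float32).reshape(5, 4)
    print(np.fft.rfftn(x).shape)                                  # (5, 3): 4 // 2 + 1
    print(np.fft.irfftn(np.zeros((5, 4), np.complex64)).shape)    # (5, 6): 2 * (4 - 1)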

View File

@ -0,0 +1,30 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test simplify scalars") {
auto a = array({-1.0f, 2.0f});
auto b = maximum(a, array(0.0f));
auto c = maximum(-a, array(0.0f));
auto d = b + c;
simplify({d});
CHECK(b.inputs()[1].id() == c.inputs()[1].id());
}
TEST_CASE("test simplify") {
auto a = array({1.0f, 2.0f});
auto b = exp(a) + exp(a);
simplify(b);
eval(b);
CHECK(b.inputs()[0].id() == b.inputs()[1].id());
}
TEST_CASE("test no simplify") {
auto a = array({1.0f, 2.0f});
auto b = cos(a) + sin(a);
simplify(b);
eval(b);
CHECK(b.inputs()[0].id() != b.inputs()[1].id());
}