mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
angelos's commit files
This commit is contained in:
parent
8ca7f9e8e9
commit
d1f86272a2
75
.gitignore
vendored
Normal file
75
.gitignore
vendored
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Metal libraries
|
||||||
|
*.metallib
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
python/mlx/share
|
||||||
|
python/mlx/include
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# vim
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# Ignore build dir
|
||||||
|
build/
|
||||||
|
|
||||||
|
# Prerequisites
|
||||||
|
*.d
|
||||||
|
|
||||||
|
# Compiled Object files
|
||||||
|
*.slo
|
||||||
|
*.lo
|
||||||
|
*.o
|
||||||
|
*.obj
|
||||||
|
|
||||||
|
# Precompiled Headers
|
||||||
|
*.gch
|
||||||
|
*.pch
|
||||||
|
|
||||||
|
# Compiled Dynamic libraries
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
*.dll
|
||||||
|
|
||||||
|
# Fortran module files
|
||||||
|
*.mod
|
||||||
|
*.smod
|
||||||
|
|
||||||
|
# Compiled Static libraries
|
||||||
|
*.lai
|
||||||
|
*.la
|
||||||
|
*.a
|
||||||
|
*.lib
|
||||||
|
|
||||||
|
# Executables
|
||||||
|
*.exe
|
||||||
|
*.out
|
||||||
|
*.app
|
||||||
|
|
||||||
|
# VSCode
|
||||||
|
.vscode/
|
||||||
|
.DS_Store
|
9
.pre-commit-config.yaml
Normal file
9
.pre-commit-config.yaml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
|
rev: v14.0.6
|
||||||
|
hooks:
|
||||||
|
- id: clang-format
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 22.10.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
197
CMakeLists.txt
Normal file
197
CMakeLists.txt
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.24)
|
||||||
|
|
||||||
|
project(mlx LANGUAGES CXX)
|
||||||
|
|
||||||
|
# ----------------------------- Setup -----------------------------
|
||||||
|
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||||
|
|
||||||
|
# ----------------------------- Configuration -----------------------------
|
||||||
|
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
||||||
|
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||||
|
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||||
|
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||||
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
|
if(NOT MLX_VERSION)
|
||||||
|
set(MLX_VERSION 0.0.1)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ----------------------------- Lib -----------------------------
|
||||||
|
|
||||||
|
include(FetchContent)
|
||||||
|
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
|
||||||
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
|
||||||
|
add_library(mlx)
|
||||||
|
|
||||||
|
# Locate the Apple frameworks needed by the Metal (GPU) backend. The lookups
# only run when the backend was requested.
if(MLX_BUILD_METAL)
  find_library(METAL_LIB Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB QuartzCore)
endif()

if(MLX_BUILD_METAL AND NOT METAL_LIB)
  message(STATUS "Metal not found. Unable to build GPU")
  # Switch the backend off for the remainder of the configure step so that
  # every later `if(MLX_BUILD_METAL)` check (subdirectories, install rules)
  # sees a consistent state instead of using an undefined metal_cpp_SOURCE_DIR.
  set(MLX_BUILD_METAL OFF)
elseif(MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")
  add_compile_definitions(_METAL_)

  # Ask sw_vers for the product version directly and strip the trailing
  # newline so the value can be used in messages and version comparisons.
  execute_process(
    COMMAND /usr/bin/sw_vers -productVersion
    OUTPUT_VARIABLE MACOS_VERSION
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET)
  if(NOT MACOS_VERSION)
    message(WARNING
      "Could not detect the macOS version; using the macOS 13 metal-cpp headers")
  endif()

  message(STATUS "Detected macOS version ${MACOS_VERSION}")

  # Compare as versions (component-wise), not as real numbers: numerically
  # "13.10" would sort below "13.3". Quoting keeps an empty value from
  # producing a malformed if() expression.
  if("${MACOS_VERSION}" VERSION_GREATER_EQUAL 14.0)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
  elseif("${MACOS_VERSION}" VERSION_GREATER_EQUAL 13.3)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
  else()
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
  endif()

  # Download the metal-cpp headers matching the detected OS release.
  FetchContent_Declare(
    metal_cpp
    URL ${METAL_CPP_URL}
  )

  FetchContent_MakeAvailable(metal_cpp)
  target_include_directories(
    mlx PUBLIC
    $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
    $<INSTALL_INTERFACE:include/metal_cpp>
  )
  # NOTE(review): the keyword-less signature is kept on purpose; CMake forbids
  # mixing it with the PUBLIC/PRIVATE form on one target and other files may
  # already link `mlx` this way — confirm before adding a keyword.
  target_link_libraries(
    mlx
    ${METAL_LIB}
    ${FOUNDATION_LIB}
    ${QUARTZ_LIB})
endif()
|
||||||
|
|
||||||
|
# Pick the CPU linear-algebra backend: Apple's Accelerate framework when it is
# available, otherwise whatever BLAS implementation find_package(BLAS) locates.
find_library(ACCELERATE_LIBRARY Accelerate)
if(ACCELERATE_LIBRARY)
  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
  set(MLX_BUILD_ACCELERATE ON)
  target_link_libraries(mlx ${ACCELERATE_LIBRARY})
  add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
  message(STATUS "Accelerate not found, using default backend.")
  set(MLX_BUILD_ACCELERATE OFF)

  # REQUIRED already aborts the configure step when no BLAS is found, so no
  # separate BLAS_FOUND check is needed afterwards.
  find_package(BLAS REQUIRED)

  # TODO find a cleaner way to do this
  # Only add the BLAS_HOME hint when the environment variable is set;
  # otherwise the hint would degrade to the bogus path "/include".
  set(mlx_blas_include_hints /usr/include /usr/local/include)
  if(DEFINED ENV{BLAS_HOME})
    list(APPEND mlx_blas_include_hints "$ENV{BLAS_HOME}/include")
  endif()
  find_path(BLAS_INCLUDE_DIRS cblas.h ${mlx_blas_include_hints})
  if(NOT BLAS_INCLUDE_DIRS)
    # Fail here with an actionable message rather than letting a -NOTFOUND
    # value reach target_include_directories.
    message(FATAL_ERROR
      "Could not find cblas.h. Install the BLAS development headers or set "
      "the BLAS_HOME environment variable.")
  endif()

  # Quote the expansions so list values print with their separators.
  message(STATUS "BLAS libraries: ${BLAS_LIBRARIES}")
  message(STATUS "BLAS include dirs: ${BLAS_INCLUDE_DIRS}")
  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
  target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
|
||||||
|
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
|
target_include_directories(
|
||||||
|
mlx
|
||||||
|
PUBLIC
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
|
$<INSTALL_INTERFACE:include>
|
||||||
|
)
|
||||||
|
|
||||||
|
if (MLX_BUILD_PYTHON_BINDINGS)
|
||||||
|
message(STATUS "Building Python bindings.")
|
||||||
|
find_package(Python COMPONENTS Interpreter Development)
|
||||||
|
find_package(pybind11 CONFIG REQUIRED)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (MLX_BUILD_TESTS)
|
||||||
|
include(CTest)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (MLX_BUILD_EXAMPLES)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (MLX_BUILD_BENCHMARKS)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ----------------------------- Installation -----------------------------
|
||||||
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
|
# Install library
|
||||||
|
install(
|
||||||
|
TARGETS mlx
|
||||||
|
EXPORT MLXTargets
|
||||||
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
|
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||||
|
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Install headers
|
||||||
|
install(
|
||||||
|
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
||||||
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||||
|
COMPONENT headers
|
||||||
|
FILES_MATCHING PATTERN "*.h"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Install metal dependencies
|
||||||
|
# metal_cpp_SOURCE_DIR is only set when the metal-cpp headers were actually
# fetched. Requiring it here prevents `install(DIRECTORY /)` — the result of an
# empty expansion — when Metal was requested but not found on this machine.
if(MLX_BUILD_METAL AND metal_cpp_SOURCE_DIR)

  # Install metal cpp
  install(
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source
  )

endif()
|
||||||
|
|
||||||
|
# Install cmake config
|
||||||
|
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
|
||||||
|
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
|
||||||
|
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
|
||||||
|
|
||||||
|
install(
|
||||||
|
EXPORT MLXTargets
|
||||||
|
FILE MLXTargets.cmake
|
||||||
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
include(CMakePackageConfigHelpers)
|
||||||
|
|
||||||
|
write_basic_package_version_file(
|
||||||
|
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||||
|
COMPATIBILITY SameMajorVersion
|
||||||
|
VERSION ${MLX_VERSION}
|
||||||
|
)
|
||||||
|
|
||||||
|
configure_package_config_file(
|
||||||
|
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
|
||||||
|
${MLX_CMAKE_BUILD_CONFIG}
|
||||||
|
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
|
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||||
|
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
install(
|
||||||
|
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||||
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
install(
|
||||||
|
DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||||
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
|
)
|
38
benchmarks/cpp/time_utils.h
Normal file
38
benchmarks/cpp/time_utils.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
#define milliseconds(x) \
|
||||||
|
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
|
||||||
|
#define time_now() std::chrono::high_resolution_clock::now()
|
||||||
|
|
||||||
|
#define TIME(FUNC, ...) \
|
||||||
|
std::cout << "Timing " << #FUNC << " ... " << std::flush \
|
||||||
|
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||||
|
<< std::endl;
|
||||||
|
|
||||||
|
#define TIMEM(MSG, FUNC, ...) \
|
||||||
|
std::cout << "Timing " \
|
||||||
|
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
||||||
|
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
||||||
|
<< std::endl;
|
||||||
|
|
||||||
|
template <typename F, typename... Args>
|
||||||
|
double time_fn(F fn, Args... args) {
|
||||||
|
// warmup
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
eval(fn(std::forward<Args>(args)...));
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_iters = 100;
|
||||||
|
auto start = time_now();
|
||||||
|
for (int i = 0; i < num_iters; i++) {
|
||||||
|
eval(fn(std::forward<Args>(args)...));
|
||||||
|
}
|
||||||
|
auto end = time_now();
|
||||||
|
return milliseconds(end - start) / static_cast<double>(num_iters);
|
||||||
|
}
|
60
benchmarks/python/batch_matmul_bench.py
Normal file
60
benchmarks/python/batch_matmul_bench.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import argparse
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
B = 8
|
||||||
|
T = 1024
|
||||||
|
D = 512
|
||||||
|
|
||||||
|
|
||||||
|
def time_batch_matmul():
|
||||||
|
mx.random.seed(3)
|
||||||
|
a = mx.random.uniform(shape=(B, T, D))
|
||||||
|
b = mx.random.uniform(shape=(D, D))
|
||||||
|
c = mx.random.uniform(shape=(B, T, D))
|
||||||
|
mx.eval(a, b, c)
|
||||||
|
|
||||||
|
time_fn(mx.matmul, a, b)
|
||||||
|
|
||||||
|
def batch_vjp_first():
|
||||||
|
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
|
||||||
|
|
||||||
|
time_fn(batch_vjp_first)
|
||||||
|
|
||||||
|
def batch_vjp_second():
|
||||||
|
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
|
||||||
|
|
||||||
|
time_fn(batch_vjp_second)
|
||||||
|
|
||||||
|
|
||||||
|
def time_unbatch_matmul(key=None):
    """Time a 2-D matmul on (B * T, D) inputs and two related matmul products.

    ``key`` is not used by the benchmark. It now defaults to ``None`` so the
    zero-argument call in the script's ``__main__`` section no longer raises
    ``TypeError``; callers that still pass a value keep working.
    """
    mx.random.seed(3)
    a = mx.random.uniform(shape=(B * T, D))
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=(B * T, D))
    # Materialize the inputs before timing.
    mx.eval(a, b, c)
    time_fn(mx.matmul, a, b)

    # The nested function names are printed by time_fn, so they are kept as-is.
    def unbatch_vjp_first():
        return mx.matmul(c, mx.transpose(b))

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        return mx.matmul(mx.transpose(a), c)

    time_fn(unbatch_vjp_second)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||||
|
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.gpu:
|
||||||
|
mx.set_default_device(mx.gpu)
|
||||||
|
else:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
time_batch_matmul()
|
||||||
|
time_unbatch_matmul()
|
20
benchmarks/python/time_utils.py
Normal file
20
benchmarks/python/time_utils.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def time_fn(fn, *args, **kwargs):
    """Print the average time of ``mx.eval(fn(*args, **kwargs))`` in msec.

    ``fn`` is called 5 times as an untimed warm-up and then 100 times under
    ``time.perf_counter``. The result is printed; nothing is returned.
    """
    # Callables without a __name__ (e.g. functools.partial objects) fall back
    # to their repr instead of raising AttributeError.
    label = getattr(fn, "__name__", repr(fn))
    # Flush so the label is visible while the benchmark is still running.
    print(f"Timing {label} ...", end=" ", flush=True)

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
|
2
docs/.clang-format
Normal file
2
docs/.clang-format
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
DisableFormat: true
|
||||||
|
SortIncludes: Never
|
18
docs/Makefile
Normal file
18
docs/Makefile
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Minimal makefile for Sphinx documentation
|
||||||
|
|
||||||
|
# You can set these variables from the command line.
|
||||||
|
SPHINXOPTS =
|
||||||
|
SPHINXBUILD = sphinx-build
|
||||||
|
SOURCEDIR = src
|
||||||
|
BUILDDIR = build
|
||||||
|
|
||||||
|
# Put it first so that "make" without argument is like "make help".
|
||||||
|
help:
|
||||||
|
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
.PHONY: help Makefile
|
||||||
|
|
||||||
|
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||||
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
|
%: Makefile
|
||||||
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
131
docs/src/examples/mlp.rst
Normal file
131
docs/src/examples/mlp.rst
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
.. _mlp:
|
||||||
|
|
||||||
|
Multi-Layer Perceptron
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
In this example we'll learn to use ``mlx.nn`` by implementing a simple
|
||||||
|
multi-layer perceptron to classify MNIST.
|
||||||
|
|
||||||
|
As a first step import the MLX packages we need:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
The model is defined as the ``MLP`` class which inherits from
|
||||||
|
:class:`mlx.nn.Module`. We follow the standard idiom to make a new module:
|
||||||
|
|
||||||
|
1. Define an ``__init__`` where the parameters and/or submodules are set up. See
|
||||||
|
the :ref:`Module class docs<module_class>` for more information on how
|
||||||
|
:class:`mlx.nn.Module` registers parameters.
|
||||||
|
2. Define a ``__call__`` where the computation is implemented.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||||
|
self.layers = [
|
||||||
|
nn.Linear(idim, odim)
|
||||||
|
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
for l in self.layers[:-1]:
|
||||||
|
x = mx.maximum(l(x), 0.0)
|
||||||
|
return self.layers[-1](x)
|
||||||
|
|
||||||
|
|
||||||
|
We define the loss function which takes the mean of the per-example cross
|
||||||
|
entropy loss. The ``mlx.nn.losses`` sub-package has implementations of some
|
||||||
|
commonly used loss functions.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def loss_fn(model, X, y):
|
||||||
|
return mx.mean(nn.losses.cross_entropy(model(X), y))
|
||||||
|
|
||||||
|
We also need a function to compute the accuracy of the model on the validation
|
||||||
|
set:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def eval_fn(model, X, y):
|
||||||
|
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
||||||
|
|
||||||
|
Next, set up the problem parameters and load the data:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
num_layers = 2
|
||||||
|
hidden_dim = 32
|
||||||
|
num_classes = 10
|
||||||
|
batch_size = 256
|
||||||
|
num_epochs = 10
|
||||||
|
learning_rate = 1e-1
|
||||||
|
|
||||||
|
# Load the data
|
||||||
|
import mnist
|
||||||
|
train_images, train_labels, test_images, test_labels = map(
|
||||||
|
mx.array, mnist.mnist()
|
||||||
|
)
|
||||||
|
|
||||||
|
Since we're using SGD, we need an iterator which shuffles and constructs
|
||||||
|
minibatches of examples in the training set:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def batch_iterate(batch_size, X, y):
|
||||||
|
perm = mx.array(np.random.permutation(y.size))
|
||||||
|
for s in range(0, y.size, batch_size):
|
||||||
|
ids = perm[s : s + batch_size]
|
||||||
|
yield X[ids], y[ids]
|
||||||
|
|
||||||
|
|
||||||
|
Finally, we put it all together by instantiating the model, the
|
||||||
|
:class:`mlx.optimizers.SGD` optimizer, and running the training loop:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
# Get a function which gives the loss and gradient of the
|
||||||
|
# loss with respect to the model's trainable parameters
|
||||||
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
|
|
||||||
|
# Instantiate the optimizer
|
||||||
|
optimizer = optim.SGD(learning_rate=learning_rate)
|
||||||
|
|
||||||
|
for e in range(num_epochs):
|
||||||
|
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
||||||
|
loss, grads = loss_and_grad_fn(model, X, y)
|
||||||
|
|
||||||
|
# Update the optimizer state and model parameters
|
||||||
|
# in a single call
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
|
# Force a graph evaluation
|
||||||
|
mx.eval(model.parameters(), optimizer.state)
|
||||||
|
|
||||||
|
accuracy = eval_fn(model, test_images, test_labels)
|
||||||
|
print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The :func:`mlx.nn.value_and_grad` function is a convenience function to get
|
||||||
|
the gradient of a loss with respect to the trainable parameters of a model.
|
||||||
|
This should not be confused with :func:`mlx.core.value_and_grad`.
|
||||||
|
|
||||||
|
The model should train to a decent accuracy (about 95%) after just a few passes
|
||||||
|
over the training set. The `full example <https://github.com/ml-explore/mlx-examples/tree/main/mlp>`_
|
||||||
|
is available in the MLX GitHub repo.
|
45
docs/src/python/array.rst
Normal file
45
docs/src/python/array.rst
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
.. _array:
|
||||||
|
|
||||||
|
Array
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
array
|
||||||
|
array.astype
|
||||||
|
array.item
|
||||||
|
array.tolist
|
||||||
|
array.dtype
|
||||||
|
array.ndim
|
||||||
|
array.shape
|
||||||
|
array.size
|
||||||
|
Dtype
|
||||||
|
array.abs
|
||||||
|
array.all
|
||||||
|
array.any
|
||||||
|
array.argmax
|
||||||
|
array.argmin
|
||||||
|
array.cos
|
||||||
|
array.dtype
|
||||||
|
array.exp
|
||||||
|
array.log
|
||||||
|
array.log1p
|
||||||
|
array.logsumexp
|
||||||
|
array.max
|
||||||
|
array.mean
|
||||||
|
array.min
|
||||||
|
array.prod
|
||||||
|
array.reciprocal
|
||||||
|
array.reshape
|
||||||
|
array.rsqrt
|
||||||
|
array.sin
|
||||||
|
array.split
|
||||||
|
array.sqrt
|
||||||
|
array.square
|
||||||
|
array.sum
|
||||||
|
array.transpose
|
||||||
|
array.T
|
||||||
|
array.var
|
94
docs/src/python/ops.rst
Normal file
94
docs/src/python/ops.rst
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
.. _ops:
|
||||||
|
|
||||||
|
Operations
|
||||||
|
==========
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
abs
|
||||||
|
add
|
||||||
|
all
|
||||||
|
allclose
|
||||||
|
any
|
||||||
|
arange
|
||||||
|
arccos
|
||||||
|
arccosh
|
||||||
|
arcsin
|
||||||
|
arcsinh
|
||||||
|
arctan
|
||||||
|
arctanh
|
||||||
|
argmax
|
||||||
|
argmin
|
||||||
|
argpartition
|
||||||
|
argsort
|
||||||
|
array_equal
|
||||||
|
broadcast_to
|
||||||
|
concatenate
|
||||||
|
convolve
|
||||||
|
conv1d
|
||||||
|
conv2d
|
||||||
|
cos
|
||||||
|
cosh
|
||||||
|
divide
|
||||||
|
equal
|
||||||
|
erf
|
||||||
|
erfinv
|
||||||
|
exp
|
||||||
|
expand_dims
|
||||||
|
full
|
||||||
|
greater
|
||||||
|
greater_equal
|
||||||
|
less
|
||||||
|
less_equal
|
||||||
|
load
|
||||||
|
log
|
||||||
|
log2
|
||||||
|
log10
|
||||||
|
log1p
|
||||||
|
logaddexp
|
||||||
|
logical_not
|
||||||
|
logsumexp
|
||||||
|
matmul
|
||||||
|
max
|
||||||
|
maximum
|
||||||
|
mean
|
||||||
|
min
|
||||||
|
minimum
|
||||||
|
multiply
|
||||||
|
negative
|
||||||
|
ones
|
||||||
|
ones_like
|
||||||
|
partition
|
||||||
|
pad
|
||||||
|
prod
|
||||||
|
reciprocal
|
||||||
|
reshape
|
||||||
|
rsqrt
|
||||||
|
save
|
||||||
|
savez
|
||||||
|
savez_compressed
|
||||||
|
sigmoid
|
||||||
|
sign
|
||||||
|
sin
|
||||||
|
sinh
|
||||||
|
softmax
|
||||||
|
sort
|
||||||
|
split
|
||||||
|
sqrt
|
||||||
|
square
|
||||||
|
squeeze
|
||||||
|
stop_gradient
|
||||||
|
subtract
|
||||||
|
sum
|
||||||
|
take
|
||||||
|
take_along_axis
|
||||||
|
tan
|
||||||
|
tanh
|
||||||
|
transpose
|
||||||
|
var
|
||||||
|
where
|
||||||
|
zeros
|
||||||
|
zeros_like
|
18
examples/cpp/timer.h
Normal file
18
examples/cpp/timer.h
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
namespace timer {

using namespace std::chrono;

// Express an arbitrary chrono duration as seconds. The value is first
// converted to whole nanoseconds and then scaled, so any sub-nanosecond part
// is discarded.
template <typename R, typename P>
inline double seconds(duration<R, P> x) {
  const nanoseconds whole_ns = duration_cast<nanoseconds>(x);
  return whole_ns.count() / 1e9;
}

// Current time point of std::chrono::high_resolution_clock.
inline auto time() {
  auto now = high_resolution_clock::now();
  return now;
}

} // namespace timer
|
359
examples/extensions/axpby/axpby.cpp
Normal file
359
examples/extensions/axpby/axpby.cpp
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
#include "axpby/axpby.h"
|
||||||
|
|
||||||
|
#ifdef ACCELERATE_NEW_LAPACK
|
||||||
|
#include <vecLib/cblas_new.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _METAL_
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Operation Implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scale and sum two vectors elementwise
|
||||||
|
* z = alpha * x + beta * y
|
||||||
|
*
|
||||||
|
* Follow numpy style broadcasting between x and y
|
||||||
|
* Inputs are upcasted to floats if needed
|
||||||
|
**/
|
||||||
|
array axpby(
|
||||||
|
const array& x, // Input array x
|
||||||
|
const array& y, // Input array y
|
||||||
|
const float alpha, // Scaling factor for x
|
||||||
|
const float beta, // Scaling factor for y
|
||||||
|
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||||
|
) {
|
||||||
|
// Promote dtypes between x and y as needed
|
||||||
|
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||||
|
|
||||||
|
// Upcast to float32 for non-floating point inputs x and y
|
||||||
|
auto out_dtype = is_floating_point(promoted_dtype)
|
||||||
|
? promoted_dtype
|
||||||
|
: promote_types(promoted_dtype, float32);
|
||||||
|
|
||||||
|
// Cast x and y up to the determined dtype (on the same stream s)
|
||||||
|
auto x_casted = astype(x, out_dtype, s);
|
||||||
|
auto y_casted = astype(y, out_dtype, s);
|
||||||
|
|
||||||
|
// Broadcast the shapes of x and y (on the same stream s)
|
||||||
|
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||||
|
auto out_shape = broadcasted_inputs[0].shape();
|
||||||
|
|
||||||
|
// Construct the array as the output of the Axpby primitive
|
||||||
|
// with the broadcasted and upcasted arrays as inputs
|
||||||
|
return array(
|
||||||
|
/* const std::vector<int>& shape = */ out_shape,
|
||||||
|
/* Dtype dtype = */ out_dtype,
|
||||||
|
/* std::unique_ptr<Primitive> primitive = */
|
||||||
|
std::make_unique<Axpby>(to_stream(s), alpha, beta),
|
||||||
|
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Primitive Common Backend Implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void axpby_impl(
|
||||||
|
const array& x,
|
||||||
|
const array& y,
|
||||||
|
array& out,
|
||||||
|
float alpha_,
|
||||||
|
float beta_) {
|
||||||
|
// We only allocate memory when we are ready to fill the output
|
||||||
|
// malloc_or_wait synchronously allocates available memory
|
||||||
|
// There may be a wait executed here if the allocation is requested
|
||||||
|
// under memory-pressured conditions
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
// Collect input and output data pointers
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
const T* y_ptr = y.data<T>();
|
||||||
|
T* out_ptr = out.data<T>();
|
||||||
|
|
||||||
|
// Cast alpha and beta to the relevant types
|
||||||
|
T alpha = static_cast<T>(alpha_);
|
||||||
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
|
// Do the elementwise operation for each output
|
||||||
|
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||||
|
// Map linear indices to offsets in x and y
|
||||||
|
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||||
|
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||||
|
|
||||||
|
// We allocate the output to be contiguous and regularly strided
|
||||||
|
// (defaults to row major) and hence it doesn't need additonal mapping
|
||||||
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Fall back implementation for evaluation on CPU */
|
||||||
|
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
// Check the inputs (registered in the op while contructing the out array)
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
|
||||||
|
// Dispatch to the correct dtype
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||||
|
} else if (out.dtype() == float16) {
|
||||||
|
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||||
|
} else if (out.dtype() == bfloat16) {
|
||||||
|
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||||
|
} else if (out.dtype() == complex64) {
|
||||||
|
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Axpby is only supported for floating point types.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Primitive Accelerate Backend Implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#ifdef ACCELERATE_NEW_LAPACK
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void axpby_impl_accelerate(
|
||||||
|
const array& x,
|
||||||
|
const array& y,
|
||||||
|
array& out,
|
||||||
|
float alpha_,
|
||||||
|
float beta_) {
|
||||||
|
// Accelerate library provides catlas_saxpby which does
|
||||||
|
// Y = (alpha * X) + (beta * Y) in place
|
||||||
|
// To use it, we first copy the data in y over to the output array
|
||||||
|
|
||||||
|
// This specialization requires both x and y be contiguous in the same mode
|
||||||
|
// i.e: corresponding linear indices in both point to corresponding elements
|
||||||
|
// The data in the output array is allocated to match the strides in y
|
||||||
|
// such that x, y, and out are contiguous in the same mode and
|
||||||
|
// no transposition is needed
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
||||||
|
y.data_size(),
|
||||||
|
y.strides(),
|
||||||
|
y.flags());
|
||||||
|
|
||||||
|
// We then copy over the elements using the contiguous vector specialization
|
||||||
|
copy_inplace(y, out, CopyType::Vector);
|
||||||
|
|
||||||
|
// Get x and y pointers for catlas_saxpby
|
||||||
|
const T* x_ptr = x.data<T>();
|
||||||
|
T* y_ptr = out.data<T>();
|
||||||
|
|
||||||
|
T alpha = static_cast<T>(alpha_);
|
||||||
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
|
// Call the inplace accelerate operator
|
||||||
|
catlas_saxpby(
|
||||||
|
/* N = */ out.size(),
|
||||||
|
/* ALPHA = */ alpha,
|
||||||
|
/* X = */ x_ptr,
|
||||||
|
/* INCX = */ 1,
|
||||||
|
/* BETA = */ beta,
|
||||||
|
/* Y = */ y_ptr,
|
||||||
|
/* INCY = */ 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Evaluate primitive on CPU using accelerate specializations */
|
||||||
|
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
|
||||||
|
// Accelerate specialization for contiguous single precision float arrays
|
||||||
|
if (out.dtype() == float32 &&
|
||||||
|
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||||
|
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||||
|
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to common backend if specializations are not available
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // Accelerate not available
|
||||||
|
|
||||||
|
/** Evaluate primitive on CPU falling back to common backend */
|
||||||
|
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
eval(inputs, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Primitive Metal Backend Implementation
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
#ifdef _METAL_
|
||||||
|
|
||||||
|
/** Evaluate primitive on GPU */
|
||||||
|
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
// Prepare inputs
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& y = inputs[1];
|
||||||
|
|
||||||
|
// Each primitive carries the stream it should execute on
|
||||||
|
// and each stream carries its device identifiers
|
||||||
|
auto& s = stream();
|
||||||
|
// We get the needed metal device using the stream
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Prepare to specialize based on contiguity
|
||||||
|
bool contiguous_kernel =
|
||||||
|
(x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||||
|
(x.flags().col_contiguous && y.flags().col_contiguous);
|
||||||
|
|
||||||
|
// Allocate output memory with strides based on specialization
|
||||||
|
if (contiguous_kernel) {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||||
|
x.data_size(),
|
||||||
|
x.strides(),
|
||||||
|
x.flags());
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve name of kernel (corresponds to axpby.metal)
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "axpby_";
|
||||||
|
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||||
|
kname << type_to_name(out);
|
||||||
|
|
||||||
|
// Make sure the metal library is available and look for it
|
||||||
|
// in the same folder as this executable if needed
|
||||||
|
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
||||||
|
|
||||||
|
// Make a kernel from this metal library
|
||||||
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
|
// Prepare to encode kernel
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
|
// those in the kernel declaration in axpby.metal
|
||||||
|
int ndim = out.ndim();
|
||||||
|
size_t nelem = out.size();
|
||||||
|
|
||||||
|
// Encode input arrays to kernel
|
||||||
|
set_array_buffer(compute_encoder, x, 0);
|
||||||
|
set_array_buffer(compute_encoder, y, 1);
|
||||||
|
|
||||||
|
// Encode output arrays to kernel
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
// Encode alpha and beta
|
||||||
|
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||||
|
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||||
|
|
||||||
|
// Encode shape, strides and ndim if needed
|
||||||
|
if (!contiguous_kernel) {
|
||||||
|
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||||
|
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// We launch 1 thread for each input and make sure that the number of
|
||||||
|
// threads in any given threadgroup is not higher than the max allowed
|
||||||
|
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
// Fix the 3D size of each threadgroup (in terms of threads)
|
||||||
|
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
|
||||||
|
|
||||||
|
// Fix the 3D size of the launch grid (in terms of threads)
|
||||||
|
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
|
||||||
|
|
||||||
|
// Launch the grid with the given number of threads divided among
|
||||||
|
// the given threadgroups
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // Metal is not available
|
||||||
|
|
||||||
|
/** Fail evaluation on GPU */
|
||||||
|
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Primitive Transforms
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/** The Jacobian-vector product. */
|
||||||
|
array Axpby::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  // Forward mode diff that pushes along the tangents.
  // The jvp transform on the primitive can be built with ops
  // that are scheduled on the same stream as the primitive.
  //
  // `primals` is not read: the op is linear in x and y, so the result
  // depends only on the tangents and the stored scale factors.

  // If argnums = {0}, we only push along x, in which case the
  // jvp is just the tangent scaled by alpha.
  // Similarly, if argnums = {1}, the jvp is just the tangent
  // scaled by beta.
  // Only tangents[0] is read on this path; the previous `size() > 1`
  // test sent the single-argument case to the two-tangent branch below,
  // which reads tangents[1].
  if (argnums.size() == 1) {
    auto scale = argnums[0] == 0 ? alpha_ : beta_;
    auto scale_arr = array(scale, tangents[0].dtype());
    return multiply(scale_arr, tangents[0], stream());
  }

  // If argnums = {0, 1}, we take contributions from both,
  // which gives us jvp = tangent_x * alpha + tangent_y * beta.
  // NOTE(review): this assumes tangents[0] belongs to x and tangents[1]
  // to y, i.e. argnums is ordered {0, 1} -- confirm against the caller.
  return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
|
||||||
|
|
||||||
|
/** The vector-Jacobian product. */
|
||||||
|
std::vector<array> Axpby::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const array& cotan,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
// Reverse mode diff
|
||||||
|
std::vector<array> vjps;
|
||||||
|
for (auto arg : argnums) {
|
||||||
|
auto scale = arg == 0 ? alpha_ : beta_;
|
||||||
|
auto scale_arr = array(scale, cotan.dtype());
|
||||||
|
vjps.push_back(multiply(scale_arr, cotan, stream()));
|
||||||
|
}
|
||||||
|
return vjps;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Vectorize primitive along given axis */
|
||||||
|
std::pair<array, int> Axpby::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Equivalence check **/
|
||||||
|
bool Axpby::is_equivalent(const Primitive& other) const {
|
||||||
|
const Axpby& r_other = static_cast<const Axpby&>(other);
|
||||||
|
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
39
examples/extensions/bindings.cpp
Normal file
39
examples/extensions/bindings.cpp
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include "axpby/axpby.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace py::literals;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||||
|
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"axpby",
|
||||||
|
&axpby,
|
||||||
|
"x"_a,
|
||||||
|
"y"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"alpha"_a,
|
||||||
|
"beta"_a,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = py::none(),
|
||||||
|
R"pbdoc(
|
||||||
|
Scale and sum two vectors elementwise
|
||||||
|
``z = alpha * x + beta * y``
|
||||||
|
|
||||||
|
Follows numpy style broadcasting between ``x`` and ``y``
|
||||||
|
Inputs are upcasted to floats if needed
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): Input array.
|
||||||
|
y (array): Input array.
|
||||||
|
alpha (float): Scaling factor for ``x``.
|
||||||
|
beta (float): Scaling factor for ``y``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: ``alpha * x + beta * y``
|
||||||
|
)pbdoc");
|
||||||
|
}
|
36
mlx/CMakeLists.txt
Normal file
36
mlx/CMakeLists.txt
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Sources that are compiled into the `mlx` target regardless of which
# backends are enabled.
target_sources(
  mlx
  PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/array.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/device.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/load.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/random.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h"
)

# The common backend directory is always part of the build.
add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/backend/common")

# With MLX_BUILD_ACCELERATE the accelerate backend directory is added;
# otherwise backend/common/default_primitives.cpp is compiled in its place.
if(MLX_BUILD_ACCELERATE)
  add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate")
else()
  target_sources(
    mlx
    PRIVATE
      "${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp"
  )
endif()

# Exactly one of the metal / no_metal backend directories is added,
# selected by MLX_BUILD_METAL.
if(MLX_BUILD_METAL)
  add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/backend/metal")
else()
  add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal")
endif()
|
143
mlx/array.cpp
Normal file
143
mlx/array.cpp
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Walks `shape` from the last axis to the first, recording for each axis the
// product of all later extents, and returns {product of all extents, those
// per-axis products}. An empty shape yields {1, {}}.
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
  std::vector<size_t> per_axis(shape.size());
  size_t running = 1;
  // Count down with an unsigned index; the loop body sees size-1 ... 0
  for (size_t axis = shape.size(); axis-- > 0;) {
    per_axis[axis] = running;
    running *= shape[axis];
  }
  return {running, per_axis};
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||||
|
auto cval = static_cast<complex64_t>(val);
|
||||||
|
init(&cval);
|
||||||
|
}
|
||||||
|
|
||||||
|
array::array(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::unique_ptr<Primitive> primitive,
|
||||||
|
const std::vector<array>& inputs)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
shape,
|
||||||
|
dtype,
|
||||||
|
std::move(primitive),
|
||||||
|
inputs)) {}
|
||||||
|
|
||||||
|
array::array(std::initializer_list<float> data)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
|
std::vector<int>{static_cast<int>(data.size())},
|
||||||
|
float32)) {
|
||||||
|
init(data.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Build an array from a shared buffer */
|
||||||
|
array::array(
|
||||||
|
allocator::Buffer data,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
deleter_t deleter)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
|
||||||
|
set_data(data, deleter);
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::detach() {
|
||||||
|
array_desc_->inputs.clear();
|
||||||
|
array_desc_->primitive = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::eval(bool retain_graph /* = false */) {
|
||||||
|
mlx::core::eval({*this}, retain_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||||
|
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||||
|
array_desc_->data_ptr = buffer.raw_ptr();
|
||||||
|
array_desc_->data_size = size();
|
||||||
|
array_desc_->flags.contiguous = true;
|
||||||
|
array_desc_->flags.row_contiguous = true;
|
||||||
|
auto max_dim = std::max_element(shape().begin(), shape().end());
|
||||||
|
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::set_data(
|
||||||
|
allocator::Buffer buffer,
|
||||||
|
size_t data_size,
|
||||||
|
std::vector<size_t> strides,
|
||||||
|
Flags flags,
|
||||||
|
deleter_t d) {
|
||||||
|
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||||
|
array_desc_->data_ptr = buffer.raw_ptr();
|
||||||
|
array_desc_->data_size = data_size;
|
||||||
|
array_desc_->strides = std::move(strides);
|
||||||
|
array_desc_->flags = flags;
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::copy_shared_buffer(
|
||||||
|
const array& other,
|
||||||
|
const std::vector<size_t>& strides,
|
||||||
|
Flags flags,
|
||||||
|
size_t data_size,
|
||||||
|
size_t offset /* = 0 */) {
|
||||||
|
array_desc_->data = other.array_desc_->data;
|
||||||
|
array_desc_->strides = strides;
|
||||||
|
array_desc_->flags = flags;
|
||||||
|
array_desc_->data_size = data_size;
|
||||||
|
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||||
|
array_desc_->data_ptr = static_cast<void*>(
|
||||||
|
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
void array::copy_shared_buffer(const array& other) {
|
||||||
|
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||||
|
: shape(shape), dtype(dtype) {
|
||||||
|
std::tie(size, strides) = cum_prod(shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
array::ArrayDesc::ArrayDesc(
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
Dtype dtype,
|
||||||
|
std::unique_ptr<Primitive> primitive,
|
||||||
|
const std::vector<array>& inputs)
|
||||||
|
: shape(shape),
|
||||||
|
dtype(dtype),
|
||||||
|
primitive(std::move(primitive)),
|
||||||
|
inputs(inputs) {
|
||||||
|
std::tie(size, strides) = cum_prod(shape);
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
is_tracer |= in.is_tracer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Needed because the Primitive type used in array.h is incomplete and the
|
||||||
|
// compiler needs to see the call to the destructor after the type is complete.
|
||||||
|
array::ArrayDesc::~ArrayDesc() = default;
|
||||||
|
|
||||||
|
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||||
|
auto start = std::vector<int>(arr.ndim(), 0);
|
||||||
|
auto end = arr.shape();
|
||||||
|
auto shape = arr.shape();
|
||||||
|
shape.erase(shape.begin());
|
||||||
|
start[0] = idx;
|
||||||
|
end[0] = idx + 1;
|
||||||
|
return reshape(slice(arr, start, end), shape);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
323
mlx/backend/accelerate/softmax.cpp
Normal file
323
mlx/backend/accelerate/softmax.cpp
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#include <simd/math.h>
|
||||||
|
#include <simd/vector.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute exp(x) in an optimizer friendly way as follows:
|
||||||
|
*
|
||||||
|
* First change the problem to computing 2**y where y = x / ln(2).
|
||||||
|
*
|
||||||
|
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||||
|
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||||
|
* shifting and for the fractional part we use a polynomial approximation.
|
||||||
|
*
|
||||||
|
* The algorithm and constants of the polynomial taken from
|
||||||
|
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||||
|
* from Cephes math library.
|
||||||
|
*
|
||||||
|
* Note: The implementation below is a general fast exp. There could be faster
|
||||||
|
* implementations for numbers strictly < 0.
|
||||||
|
*/
|
||||||
|
inline simd_float16 simd_fast_exp(simd_float16 x) {
|
||||||
|
x *= 1.442695; // multiply with log_2(e)
|
||||||
|
simd_float16 ipart, fpart;
|
||||||
|
simd_int16 epart;
|
||||||
|
x = simd_clamp(x, -80, 80);
|
||||||
|
ipart = simd::floor(x + 0.5);
|
||||||
|
fpart = x - ipart;
|
||||||
|
|
||||||
|
x = 1.535336188319500e-4f;
|
||||||
|
x = x * fpart + 1.339887440266574e-3f;
|
||||||
|
x = x * fpart + 9.618437357674640e-3f;
|
||||||
|
x = x * fpart + 5.550332471162809e-2f;
|
||||||
|
x = x * fpart + 2.402264791363012e-1f;
|
||||||
|
x = x * fpart + 6.931472028550421e-1f;
|
||||||
|
x = x * fpart + 1.000000000000000f;
|
||||||
|
|
||||||
|
// generate 2**ipart in the floating point representation using integer
|
||||||
|
// bitshifting
|
||||||
|
epart = (simd_int(ipart) + 127) << 23;
|
||||||
|
|
||||||
|
return (*(simd_float16*)&epart) * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The ARM neon equivalent of the fast exp above.
|
||||||
|
*/
|
||||||
|
inline float16x8_t neon_fast_exp(float16x8_t x) {
  x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
  x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
  x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14

  // Split into ipart = floor(x + 0.5) and the remainder fpart = x - ipart
  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
  float16x8_t fpart = vsubq_f16(x, ipart);

  // Horner evaluation of the 2**fpart polynomial, one fused multiply-add
  // (coefficient + x * fpart) per coefficient. The coefficient sequence
  // matches simd_fast_exp; each coefficient must appear exactly once.
  x = vdupq_n_f16(1.535336188319500e-4f);
  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);

  // generate 2**ipart in the floating point representation using integer
  // bitshifting: add 15 and shift left by 10 bits
  int16x8_t epart = vcvtq_s16_f16(ipart);
  epart = vaddq_s16(epart, vdupq_n_s16(15));
  epart = vshlq_n_s16(epart, 10);

  return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of folding maximum for ARM neon. This should possibly be
|
||||||
|
* refactored out of softmax.cpp at some point.
|
||||||
|
*/
|
||||||
|
inline float16_t neon_reduce_max(float16x8_t x) {
|
||||||
|
float16x4_t y;
|
||||||
|
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
|
||||||
|
y = vpmax_f16(y, y);
|
||||||
|
y = vpmax_f16(y, y);
|
||||||
|
return vget_lane_f16(y, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of folding sum for ARM neon. This should possibly be
|
||||||
|
* refactored out of softmax.cpp at some point.
|
||||||
|
*/
|
||||||
|
inline float16_t neon_reduce_add(float16x8_t x) {
|
||||||
|
float16x4_t y;
|
||||||
|
float16x4_t zero = vdup_n_f16(0);
|
||||||
|
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
|
||||||
|
y = vpadd_f16(y, zero);
|
||||||
|
y = vpadd_f16(y, zero);
|
||||||
|
return vget_lane_f16(y, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename VT>
|
||||||
|
struct AccelerateSimdOps {
|
||||||
|
VT init(T a) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
VT load(const T* a) {
|
||||||
|
return *(VT*)a;
|
||||||
|
}
|
||||||
|
|
||||||
|
void store(T* dst, VT x) {
|
||||||
|
*(VT*)dst = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
VT max(VT a, VT b) {
|
||||||
|
return simd_max(a, b);
|
||||||
|
};
|
||||||
|
|
||||||
|
VT exp(VT x) {
|
||||||
|
return simd_fast_exp(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
VT add(VT a, VT b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
VT sub(VT a, T b) {
|
||||||
|
return a - b;
|
||||||
|
}
|
||||||
|
|
||||||
|
VT mul(VT a, VT b) {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
VT mul(VT a, T b) {
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
T reduce_max(VT x) {
|
||||||
|
return simd_reduce_max(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
T reduce_add(VT x) {
|
||||||
|
return simd_reduce_add(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename VT>
|
||||||
|
struct NeonFp16SimdOps {
|
||||||
|
VT init(T a) {
|
||||||
|
return vdupq_n_f16(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
VT load(const T* a) {
|
||||||
|
return vld1q_f16(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void store(T* dst, VT x) {
|
||||||
|
vst1q_f16(dst, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
VT max(VT a, VT b) {
|
||||||
|
return vmaxq_f16(a, b);
|
||||||
|
};
|
||||||
|
|
||||||
|
VT exp(VT x) {
|
||||||
|
return neon_fast_exp(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
VT add(VT a, VT b) {
|
||||||
|
return vaddq_f16(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
VT sub(VT a, T b) {
|
||||||
|
return vsubq_f16(a, vdupq_n_f16(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
VT mul(VT a, VT b) {
|
||||||
|
return vmulq_f16(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
VT mul(VT a, T b) {
|
||||||
|
return vmulq_f16(a, vdupq_n_f16(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
T reduce_max(VT x) {
|
||||||
|
return neon_reduce_max(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
T reduce_add(VT x) {
|
||||||
|
return neon_reduce_add(x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename VT, typename Ops, int N>
|
||||||
|
void softmax(const array& in, array& out) {
|
||||||
|
Ops ops;
|
||||||
|
|
||||||
|
const T* in_ptr = in.data<T>();
|
||||||
|
T* out_ptr = out.data<T>();
|
||||||
|
int M = in.shape().back();
|
||||||
|
int L = in.data_size() / M;
|
||||||
|
const T* current_in_ptr;
|
||||||
|
T* current_out_ptr;
|
||||||
|
|
||||||
|
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||||
|
// Find the maximum
|
||||||
|
current_in_ptr = in_ptr;
|
||||||
|
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
|
||||||
|
size_t s = M;
|
||||||
|
while (s >= N) {
|
||||||
|
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
|
||||||
|
current_in_ptr += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
T maximum = ops.reduce_max(vmaximum);
|
||||||
|
while (s-- > 0) {
|
||||||
|
maximum = std::max(maximum, *current_in_ptr);
|
||||||
|
current_in_ptr++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the normalizer and the exponentials
|
||||||
|
VT vnormalizer = ops.init(0.0);
|
||||||
|
current_out_ptr = out_ptr;
|
||||||
|
current_in_ptr = in_ptr;
|
||||||
|
s = M;
|
||||||
|
while (s >= N) {
|
||||||
|
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
|
||||||
|
ops.store(current_out_ptr, vexp);
|
||||||
|
*(VT*)current_out_ptr = vexp;
|
||||||
|
vnormalizer = ops.add(vnormalizer, vexp);
|
||||||
|
current_in_ptr += N;
|
||||||
|
current_out_ptr += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
T normalizer = ops.reduce_add(vnormalizer);
|
||||||
|
while (s-- > 0) {
|
||||||
|
T _exp = std::exp(*current_in_ptr - maximum);
|
||||||
|
*current_out_ptr = _exp;
|
||||||
|
normalizer += _exp;
|
||||||
|
current_in_ptr++;
|
||||||
|
current_out_ptr++;
|
||||||
|
}
|
||||||
|
normalizer = 1 / normalizer;
|
||||||
|
|
||||||
|
// Normalize
|
||||||
|
current_out_ptr = out_ptr;
|
||||||
|
s = M;
|
||||||
|
while (s >= N) {
|
||||||
|
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||||
|
current_out_ptr += N;
|
||||||
|
s -= N;
|
||||||
|
}
|
||||||
|
while (s-- > 0) {
|
||||||
|
*current_out_ptr *= normalizer;
|
||||||
|
current_out_ptr++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous
|
||||||
|
auto check_input = [](array x) {
|
||||||
|
if (x.strides()[x.ndim() - 1] == 1) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy(x, x_copy, CopyType::General);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
array in = check_input(std::move(inputs[0]));
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
case uint8:
|
||||||
|
case uint16:
|
||||||
|
case uint32:
|
||||||
|
case uint64:
|
||||||
|
case int8:
|
||||||
|
case int16:
|
||||||
|
case int32:
|
||||||
|
case int64:
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Softmax is defined only for floating point types");
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
|
||||||
|
in, out);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
softmax<
|
||||||
|
float16_t,
|
||||||
|
float16x8_t,
|
||||||
|
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||||
|
8>(in, out);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
eval(inputs, out);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
eval(inputs, out);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
216
mlx/backend/common/binary.cpp
Normal file
216
mlx/backend/common/binary.cpp
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||||
|
DefaultScalarVector<T, U, Op> opsv(op);
|
||||||
|
DefaultVectorScalar<T, U, Op> opvs(op);
|
||||||
|
DefaultVectorVector<T, U, Op> opvv(op);
|
||||||
|
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
comparison_op<bool, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
comparison_op<uint8_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
comparison_op<uint16_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
comparison_op<uint32_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
comparison_op<uint64_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
comparison_op<int8_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
comparison_op<int16_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
comparison_op<int32_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
comparison_op<int64_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
comparison_op<float16_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
comparison_op<float, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
comparison_op<bfloat16_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
comparison_op<complex64_t, bool>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Add::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
if (equal_nan_) {
|
||||||
|
comparison_op(inputs[0], inputs[1], out, [](auto x, auto y) {
|
||||||
|
return x == y || (std::isnan(x) && std::isnan(y));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
comparison_op(
|
||||||
|
inputs[0], inputs[1], out, [](auto x, auto y) { return x == y; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Greater::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
comparison_op(
|
||||||
|
inputs[0], inputs[1], out, [](auto x, auto y) { return x > y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
comparison_op(
|
||||||
|
inputs[0], inputs[1], out, [](auto x, auto y) { return x >= y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Less::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
comparison_op(
|
||||||
|
inputs[0], inputs[1], out, [](auto x, auto y) { return x < y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
comparison_op(
|
||||||
|
inputs[0], inputs[1], out, [](auto x, auto y) { return x <= y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
auto op = [](auto x, auto y) {
|
||||||
|
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||||
|
auto maxval = (x > y) ? x : y;
|
||||||
|
auto minval = (x > y) ? y : x;
|
||||||
|
return (minval == -inf || maxval == inf)
|
||||||
|
? maxval
|
||||||
|
: static_cast<decltype(x)>(
|
||||||
|
maxval + std::log1p(std::exp(minval - maxval)));
|
||||||
|
};
|
||||||
|
if (is_floating_point(out.dtype())) {
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
binary_op<float>(a, b, out, op);
|
||||||
|
} else if (out.dtype() == float16) {
|
||||||
|
binary_op<float16_t>(a, b, out, op);
|
||||||
|
} else if (out.dtype() == bfloat16) {
|
||||||
|
binary_op<bfloat16_t>(a, b, out, op);
|
||||||
|
} else {
|
||||||
|
std::ostringstream err;
|
||||||
|
err << "[logaddexp] Does not support " << out.dtype();
|
||||||
|
throw std::invalid_argument(err.str());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||||
|
" non floating point type.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Element-wise `!=` comparison of the two inputs.
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& lhs = inputs[0];
  const auto& rhs = inputs[1];
  comparison_op(lhs, rhs, out, [](auto p, auto q) { return p != q; });
}
|
||||||
|
|
||||||
|
// Functor computing `base` raised to `exp`.
//  - Non-integral types defer to std::pow.
//  - Integral types use exponentiation by squaring and throw
//    std::invalid_argument for a negative exponent.
struct PowerFn {
  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
    return std::pow(base, exp);
  }

  template <typename T>
  std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
    if (exp < 0) {
      throw std::invalid_argument(
          "Integers cannot be raised to negative powers");
    }
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      // Only square when another bit remains to be processed. Squaring
      // unconditionally performs one multiplication whose result is never
      // used and which can overflow (undefined behavior for signed types)
      // even though the final result is representable.
      if (exp) {
        base *= base;
      }
    }
    return res;
  }
};
|
||||||
|
|
||||||
|
void Power::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, PowerFn{});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
auto& a = inputs[0];
|
||||||
|
auto& b = inputs[1];
|
||||||
|
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
554
mlx/backend/common/binary.h
Normal file
554
mlx/backend/common/binary.h
Normal file
@ -0,0 +1,554 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Classifies the memory layout of the two inputs of a binary op. The value is
// picked by get_binary_op_type() and decides which kernel variant is used.
enum BinaryOpType {
|
||||||
|
// Both inputs hold exactly one data element.
ScalarScalar,
|
||||||
|
// `a` holds one data element and `b` is flagged contiguous.
ScalarVector,
|
||||||
|
// `b` holds one data element and `a` is flagged contiguous.
VectorScalar,
|
||||||
|
// Both inputs are row contiguous, or both are column contiguous.
VectorVector,
|
||||||
|
// Any other combination of layouts.
General,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pick the BinaryOpType matching the layout of `a` and `b`. The checks are
// ordered from the most specific layout to the most general one.
BinaryOpType get_binary_op_type(const array& a, const array& b) {
  const bool a_is_scalar = a.data_size() == 1;
  const bool b_is_scalar = b.data_size() == 1;

  if (a_is_scalar && b_is_scalar) {
    return ScalarScalar;
  }
  if (a_is_scalar && b.flags().contiguous) {
    return ScalarVector;
  }
  if (b_is_scalar && a.flags().contiguous) {
    return VectorScalar;
  }

  const bool both_row_major =
      a.flags().row_contiguous && b.flags().row_contiguous;
  const bool both_col_major =
      a.flags().col_contiguous && b.flags().col_contiguous;
  return (both_row_major || both_col_major) ? VectorVector : General;
}
|
||||||
|
|
||||||
|
void set_binary_op_output_data(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
BinaryOpType bopt) {
|
||||||
|
switch (bopt) {
|
||||||
|
case ScalarScalar:
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||||
|
break;
|
||||||
|
case ScalarVector:
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||||
|
b.data_size(),
|
||||||
|
b.strides(),
|
||||||
|
b.flags());
|
||||||
|
break;
|
||||||
|
case VectorScalar:
|
||||||
|
case VectorVector:
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||||
|
a.data_size(),
|
||||||
|
a.strides(),
|
||||||
|
a.flags());
|
||||||
|
break;
|
||||||
|
case General:
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Placeholder type callers pass to binary_op() to request the default
// vectorized implementation for a slot. It is expected to be swapped out
// before dispatch, so invoking it trips an assertion.
struct UseDefaultBinaryOp {
  template <typename T, typename U>
  void operator()(const T*, const T*, U*, int) {
    // Should we throw? This should normally never be called.
    assert(false);
  }
};
|
||||||
|
|
||||||
|
// Applies `op(a[i], *b)` for i in [0, size) and stores the results in `dst`.
// The scalar operand is read once, before the loop.
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T rhs = *b;
    for (int i = 0; i < size; ++i) {
      dst[i] = op(a[i], rhs);
    }
  }
};
|
||||||
|
|
||||||
|
// Applies `op(*a, b[i])` for i in [0, size) and stores the results in `dst`.
// The scalar operand is read once, before the loop.
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T lhs = *a;
    for (int i = 0; i < size; ++i) {
      dst[i] = op(lhs, b[i]);
    }
  }
};
|
||||||
|
|
||||||
|
// Applies `op(a[i], b[i])` for i in [0, size) and stores the results in
// `dst`.
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    for (int i = 0; i < size; ++i) {
      dst[i] = op(a[i], b[i]);
    }
  }
};
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < out.size(); ++i) {
|
||||||
|
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims1(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
dst += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
dst += stride;
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||||
|
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[2];
|
||||||
|
b_idx += b.strides()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||||
|
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||||
|
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||||
|
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[3];
|
||||||
|
b_idx += b.strides()[3];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||||
|
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||||
|
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dispatch_dims(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op) {
|
||||||
|
switch (out.ndim()) {
|
||||||
|
case 1:
|
||||||
|
binary_op_dims1<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
case 2:
|
||||||
|
binary_op_dims2<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
case 3:
|
||||||
|
binary_op_dims3<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
case 4:
|
||||||
|
binary_op_dims4<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
for (size_t i = 0; i < out.size(); i++) {
|
||||||
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
|
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||||
|
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dispatch_dims(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
int dim,
|
||||||
|
int stride) {
|
||||||
|
// Number of dimensions to loop over for vectorized ops
|
||||||
|
switch (dim) {
|
||||||
|
case 1:
|
||||||
|
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
|
||||||
|
return;
|
||||||
|
case 2:
|
||||||
|
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
for (size_t i = 0; i < out.size(); i += stride) {
|
||||||
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
|
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||||
|
dst += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
typename Op,
|
||||||
|
typename OpSV,
|
||||||
|
typename OpVS,
|
||||||
|
typename OpVV>
|
||||||
|
void binary_op(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
OpSV opsv,
|
||||||
|
OpVS opvs,
|
||||||
|
OpVV opvv) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
// The full computation is scalar scalar so call the base op once
|
||||||
|
if (bopt == ScalarScalar) {
|
||||||
|
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The full computation is scalar vector so delegate to the op
|
||||||
|
if (bopt == ScalarVector) {
|
||||||
|
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The full computation is vector scalar so delegate to the op
|
||||||
|
if (bopt == VectorScalar) {
|
||||||
|
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The full computation is vector vector so delegate to the op
|
||||||
|
if (bopt == VectorVector) {
|
||||||
|
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// General computation so let's try to optimize
|
||||||
|
|
||||||
|
// Get the left-most dim such that the array is row contiguous after
|
||||||
|
auto& strides = out.strides();
|
||||||
|
auto leftmost_rc_dim = [&strides](const array& arr) {
|
||||||
|
int d = arr.ndim() - 1;
|
||||||
|
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
|
||||||
|
}
|
||||||
|
return d + 1;
|
||||||
|
};
|
||||||
|
auto a_rc_dim = leftmost_rc_dim(a);
|
||||||
|
auto b_rc_dim = leftmost_rc_dim(b);
|
||||||
|
|
||||||
|
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||||
|
auto leftmost_s_dim = [](const array& arr) {
|
||||||
|
int d = arr.ndim() - 1;
|
||||||
|
for (; d >= 0 && arr.strides()[d] == 0; d--) {
|
||||||
|
}
|
||||||
|
return d + 1;
|
||||||
|
};
|
||||||
|
auto a_s_dim = leftmost_s_dim(a);
|
||||||
|
auto b_s_dim = leftmost_s_dim(b);
|
||||||
|
|
||||||
|
auto ndim = out.ndim();
|
||||||
|
|
||||||
|
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||||
|
int dim = ndim;
|
||||||
|
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||||
|
bopt = VectorVector;
|
||||||
|
dim = d;
|
||||||
|
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||||
|
// contiguous
|
||||||
|
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||||
|
bopt = VectorScalar;
|
||||||
|
dim = d;
|
||||||
|
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||||
|
// contiguous
|
||||||
|
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||||
|
bopt = ScalarVector;
|
||||||
|
dim = d;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||||
|
// contiguous methods above. Except for the case that the flags do not
|
||||||
|
// correspond to the underlying contiguity.
|
||||||
|
size_t stride;
|
||||||
|
if (dim == 0 || strides[dim - 1] < 16) {
|
||||||
|
stride = 1;
|
||||||
|
bopt = General;
|
||||||
|
dim = ndim;
|
||||||
|
} else {
|
||||||
|
stride = strides[dim - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (bopt) {
|
||||||
|
case VectorVector:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
|
||||||
|
break;
|
||||||
|
case VectorScalar:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
|
||||||
|
break;
|
||||||
|
case ScalarVector:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||||
|
void binary_op(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
OpSV opsv,
|
||||||
|
OpVS opvs,
|
||||||
|
OpVV opvv) {
|
||||||
|
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||||
|
// with template specializations and overloading. Would it be simpler?
|
||||||
|
|
||||||
|
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||||
|
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||||
|
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// opsv and opvs were UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
opvv);
|
||||||
|
}
|
||||||
|
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// opsv and opvv were UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
opvs,
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// opsv was UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||||
|
}
|
||||||
|
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||||
|
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// opvs and opvv were UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
op,
|
||||||
|
opsv,
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// opvs was UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||||
|
}
|
||||||
|
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// opvv was UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// All ops provided
|
||||||
|
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Op>
|
||||||
|
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||||
|
DefaultScalarVector<T, T, Op> opsv(op);
|
||||||
|
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||||
|
DefaultVectorVector<T, T, Op> opvv(op);
|
||||||
|
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Ops>
|
||||||
|
void binary(const array& a, const array& b, array& out, Ops... ops) {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t>(a, b, out, ops...);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
85
mlx/backend/common/fft.cpp
Normal file
85
mlx/backend/common/fft.cpp
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "mlx/3rdparty/pocketfft.h"
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void FFT::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
auto& in = inputs[0];
|
||||||
|
std::vector<std::ptrdiff_t> strides_in(
|
||||||
|
in.strides().begin(), in.strides().end());
|
||||||
|
for (auto& s : strides_in) {
|
||||||
|
s *= in.itemsize();
|
||||||
|
}
|
||||||
|
std::vector<std::ptrdiff_t> strides_out(
|
||||||
|
out.strides().begin(), out.strides().end());
|
||||||
|
for (auto& s : strides_out) {
|
||||||
|
s *= out.itemsize();
|
||||||
|
}
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
std::vector<size_t> shape;
|
||||||
|
if (out.dtype() == float32) {
|
||||||
|
shape.insert(shape.end(), out.shape().begin(), out.shape().end());
|
||||||
|
} else {
|
||||||
|
shape.insert(shape.end(), in.shape().begin(), in.shape().end());
|
||||||
|
}
|
||||||
|
|
||||||
|
float scale = 1.0f;
|
||||||
|
if (inverse_) {
|
||||||
|
size_t nelem = std::accumulate(
|
||||||
|
axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
|
||||||
|
return x * shape[y];
|
||||||
|
});
|
||||||
|
scale /= nelem;
|
||||||
|
}
|
||||||
|
if (in.dtype() == complex64 && out.dtype() == complex64) {
|
||||||
|
auto in_ptr =
|
||||||
|
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||||
|
auto out_ptr =
|
||||||
|
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||||
|
pocketfft::c2c(
|
||||||
|
shape,
|
||||||
|
strides_in,
|
||||||
|
strides_out,
|
||||||
|
axes_,
|
||||||
|
!inverse_,
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
scale);
|
||||||
|
} else if (in.dtype() == float32 && out.dtype() == complex64) {
|
||||||
|
auto in_ptr = in.data<float>();
|
||||||
|
auto out_ptr =
|
||||||
|
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||||
|
pocketfft::r2c(
|
||||||
|
shape,
|
||||||
|
strides_in,
|
||||||
|
strides_out,
|
||||||
|
axes_,
|
||||||
|
!inverse_,
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
scale);
|
||||||
|
} else if (in.dtype() == complex64 && out.dtype() == float32) {
|
||||||
|
auto in_ptr =
|
||||||
|
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||||
|
auto out_ptr = out.data<float>();
|
||||||
|
pocketfft::c2r(
|
||||||
|
shape,
|
||||||
|
strides_in,
|
||||||
|
strides_out,
|
||||||
|
axes_,
|
||||||
|
!inverse_,
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
scale);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[FFT] Received unexpected input and output type combination.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
98
mlx/backend/common/softmax.cpp
Normal file
98
mlx/backend/common/softmax.cpp
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void softmax(const array& in, array& out) {
|
||||||
|
const T* in_ptr = in.data<T>();
|
||||||
|
T* out_ptr = out.data<T>();
|
||||||
|
int N = in.shape().back();
|
||||||
|
int M = in.data_size() / N;
|
||||||
|
const T* current_in_ptr;
|
||||||
|
T* current_out_ptr;
|
||||||
|
|
||||||
|
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
|
||||||
|
// Find the maximum
|
||||||
|
current_in_ptr = in_ptr;
|
||||||
|
T maximum = *current_in_ptr;
|
||||||
|
for (int j = 0; j < N; j++, current_in_ptr++) {
|
||||||
|
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the normalizer and the exponentials
|
||||||
|
T normalizer = 0;
|
||||||
|
current_out_ptr = out_ptr;
|
||||||
|
current_in_ptr = in_ptr;
|
||||||
|
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
|
||||||
|
T expv = std::exp(*current_in_ptr - maximum);
|
||||||
|
normalizer += expv;
|
||||||
|
*current_out_ptr = expv;
|
||||||
|
}
|
||||||
|
normalizer = 1 / normalizer;
|
||||||
|
|
||||||
|
// Normalize
|
||||||
|
current_out_ptr = out_ptr;
|
||||||
|
for (int j = 0; j < N; j++, current_out_ptr++) {
|
||||||
|
*current_out_ptr *= normalizer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
|
||||||
|
// Make sure that the last dimension is contiguous
|
||||||
|
auto check_input = [](array x) {
|
||||||
|
if (x.strides().back() == 1) {
|
||||||
|
return x;
|
||||||
|
} else {
|
||||||
|
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||||
|
copy(x, x_copy, CopyType::General);
|
||||||
|
return x_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
array in = check_input(std::move(inputs[0]));
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
|
||||||
|
switch (in.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
case uint8:
|
||||||
|
case uint16:
|
||||||
|
case uint32:
|
||||||
|
case uint64:
|
||||||
|
case int8:
|
||||||
|
case int16:
|
||||||
|
case int32:
|
||||||
|
case int64:
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Softmax is defined only for floating point types");
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
softmax<float>(in, out);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
softmax<float16_t>(in, out);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
softmax<bfloat16_t>(in, out);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[Softmax] Not yet implemented for complex64");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
200
mlx/backend/metal/allocator.cpp
Normal file
200
mlx/backend/metal/allocator.cpp
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
#include "mlx/backend/metal/allocator.h"
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
|
#include <mach/vm_page_size.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace allocator {
|
||||||
|
|
||||||
|
// Global allocator accessor for this backend: hands out the Metal allocator.
Allocator& allocator() {
  Allocator& metal_allocator = metal::allocator();
  return metal_allocator;
}
|
||||||
|
|
||||||
|
void* Buffer::raw_ptr() {
|
||||||
|
return static_cast<MTL::Buffer*>(ptr_)->contents();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace allocator
|
||||||
|
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Start with an empty cache bound to `device`. The garbage collection limit
// is set to 95% of the device's recommended maximum working set size.
BufferCache::BufferCache(MTL::Device* device)
    : device_(device),
      head_(nullptr),
      tail_(nullptr),
      pool_size_(0),
      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {
}
|
||||||
|
|
||||||
|
// Release every cached buffer when the cache is destroyed.
BufferCache::~BufferCache() {
  this->clear();
}
|
||||||
|
|
||||||
|
void BufferCache::clear() {
|
||||||
|
std::lock_guard<std::mutex> lk(cache_mutex_);
|
||||||
|
for (auto& [size, holder] : buffer_pool_) {
|
||||||
|
if (holder->buf)
|
||||||
|
holder->buf->release();
|
||||||
|
delete holder;
|
||||||
|
}
|
||||||
|
buffer_pool_.clear();
|
||||||
|
pool_size_ = 0;
|
||||||
|
head_ = nullptr;
|
||||||
|
tail_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for a cached buffer able to hold `size` bytes without wasting half of
// it or more (candidate sizes lie in [size, 2 * size)). Entries whose buffer
// was already released are dropped along the way. Returns nullptr when no
// suitable buffer is cached.
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
  std::lock_guard<std::mutex> lk(cache_mutex_);

  MTL::Buffer* found = nullptr;

  // Smallest cached entry that is at least `size` bytes
  auto it = buffer_pool_.lower_bound(size);

  // Make sure we use > 50% of the available memory
  while (it != buffer_pool_.end() && it->first < 2 * size) {
    auto* holder = it->second;
    found = holder->buf;

    // The entry leaves the cache whether or not it still owns a buffer
    remove_from_list(holder);
    delete holder;
    it = buffer_pool_.erase(it);

    if (found) {
      break;
    }
  }

  if (found) {
    pool_size_ -= found->length();
  }
  return found;
}
|
||||||
|
|
||||||
|
// Put `buf` back into the cache as the most recently used entry. A null
// buffer is ignored.
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
  std::lock_guard<std::mutex> lk(cache_mutex_);

  if (!buf) {
    return;
  }

  // Add to cache
  BufferHolder* holder = new BufferHolder(buf);
  add_at_head(holder);
  pool_size_ += buf->length();
  buffer_pool_.insert({buf->length(), holder});
}
|
||||||
|
|
||||||
|
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||||
|
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
|
||||||
|
|
||||||
|
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||||
|
size_t old_pool_size = pool_size_;
|
||||||
|
clear();
|
||||||
|
return old_pool_size;
|
||||||
|
} else {
|
||||||
|
std::lock_guard<std::mutex> lk(cache_mutex_);
|
||||||
|
size_t total_bytes_freed = 0;
|
||||||
|
|
||||||
|
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||||
|
if (tail_->buf) {
|
||||||
|
total_bytes_freed += tail_->buf->length();
|
||||||
|
tail_->buf->release();
|
||||||
|
tail_->buf = nullptr;
|
||||||
|
}
|
||||||
|
remove_from_list(tail_);
|
||||||
|
}
|
||||||
|
|
||||||
|
pool_size_ -= total_bytes_freed;
|
||||||
|
return total_bytes_freed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert `to_add` at the front of the doubly linked list. A null holder is
// ignored.
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
  if (to_add == nullptr) {
    return;
  }

  if (head_ == nullptr) {
    // Empty list: the new holder is both head and tail.
    head_ = to_add;
    tail_ = to_add;
    return;
  }

  // Link in front of the current head.
  head_->prev = to_add;
  to_add->next = head_;
  head_ = to_add;
}
|
||||||
|
|
||||||
|
// Unlink `to_remove` from the doubly linked list and null its links. A null
// holder is ignored; a holder that matches none of the cases below only has
// its links nulled.
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
  if (to_remove == nullptr) {
    return;
  }

  auto* before = to_remove->prev;
  auto* after = to_remove->next;

  if (before && after) {
    // Interior node: bridge the neighbours.
    before->next = after;
    after->prev = before;
  } else if (before && to_remove == tail_) {
    // Tail with a predecessor.
    tail_ = before;
    tail_->next = nullptr;
  } else if (to_remove == head_ && after) {
    // Head with a successor.
    head_ = after;
    head_->prev = nullptr;
  } else if (to_remove == head_ && to_remove == tail_) {
    // Only element of the list.
    head_ = nullptr;
    tail_ = nullptr;
  }

  to_remove->prev = nullptr;
  to_remove->next = nullptr;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Bind the allocator to the GPU device and create its buffer cache. New
// allocations are refused once the device has allocated 1.5x its recommended
// maximum working set size.
MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      buffer_cache_(device_),
      peak_allocated_size_(0),
      block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {
}
|
||||||
|
|
||||||
|
// Allocate a buffer of at least `size` bytes, preferring a cached buffer over
// a fresh device allocation. Returns a Buffer wrapping nullptr when the
// device is already at or above the block limit.
Buffer MetalAllocator::malloc(size_t size) {
  // Requests larger than a page are rounded up to a whole number of pages.
  if (size > vm_page_size) {
    const size_t num_pages = (size + vm_page_size - 1) / vm_page_size;
    size = vm_page_size * num_pages;
  }

  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);

  if (buf == nullptr) {
    // If we are under very high memory pressure, we don't allocate further
    if (device_->currentAllocatedSize() >= block_limit_) {
      return Buffer{nullptr};
    }

    // If we are still under memory pressure, try cleaning the cache
    if (buffer_cache_.can_garbage_collect()) {
      buffer_cache_.release_cached_buffers(size);
    }

    // Nothing reusable in the cache: allocate a new shared, hazard-tracked
    // buffer
    size_t res_opt = MTL::ResourceStorageModeShared;
    res_opt |= MTL::ResourceHazardTrackingModeTracked;
    buf = device_->newBuffer(size, res_opt);
  }

  // Track the high-water mark of device memory
  peak_allocated_size_ =
      std::max(peak_allocated_size_, device_->currentAllocatedSize());

  return Buffer{static_cast<void*>(buf)};
}
|
||||||
|
|
||||||
|
// Hand the buffer's underlying MTL::Buffer to the buffer cache rather than
// releasing it directly.
void MetalAllocator::free(Buffer buffer) {
  buffer_cache_.recycle_to_cache(static_cast<MTL::Buffer*>(buffer.ptr()));
}
|
||||||
|
|
||||||
|
// Accessor for the single MetalAllocator, held as a function-local static and
// therefore constructed on the first call.
MetalAllocator& allocator() {
  static MetalAllocator instance;
  return instance;
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
76
mlx/backend/metal/allocator.h
Normal file
76
mlx/backend/metal/allocator.h
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <mutex>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
|
using allocator::Buffer;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class BufferCache {
|
||||||
|
public:
|
||||||
|
BufferCache(MTL::Device* device);
|
||||||
|
~BufferCache();
|
||||||
|
void clear();
|
||||||
|
|
||||||
|
MTL::Buffer* reuse_from_cache(size_t size);
|
||||||
|
void recycle_to_cache(MTL::Buffer* buf);
|
||||||
|
size_t release_cached_buffers(size_t min_bytes_to_free);
|
||||||
|
|
||||||
|
bool can_garbage_collect() {
|
||||||
|
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct BufferHolder {
|
||||||
|
public:
|
||||||
|
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
|
||||||
|
|
||||||
|
BufferHolder* prev;
|
||||||
|
BufferHolder* next;
|
||||||
|
MTL::Buffer* buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
void add_at_head(BufferHolder* to_add);
|
||||||
|
void remove_from_list(BufferHolder* to_remove);
|
||||||
|
|
||||||
|
MTL::Device* device_;
|
||||||
|
std::mutex cache_mutex_;
|
||||||
|
|
||||||
|
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||||
|
BufferHolder* head_;
|
||||||
|
BufferHolder* tail_;
|
||||||
|
size_t pool_size_;
|
||||||
|
size_t gc_limit_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class MetalAllocator : public allocator::Allocator {
|
||||||
|
/** Allocator for Metal GPUs. */
|
||||||
|
public:
|
||||||
|
virtual Buffer malloc(size_t size) override;
|
||||||
|
virtual void free(Buffer buffer) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
MTL::Device* device_;
|
||||||
|
MetalAllocator();
|
||||||
|
friend MetalAllocator& allocator();
|
||||||
|
|
||||||
|
// Caching allocator
|
||||||
|
BufferCache buffer_cache_;
|
||||||
|
|
||||||
|
// Allocation stats
|
||||||
|
size_t peak_allocated_size_;
|
||||||
|
size_t block_limit_;
|
||||||
|
};
|
||||||
|
|
||||||
|
MetalAllocator& allocator();
|
||||||
|
|
||||||
|
} // namespace mlx::core::metal
|
555
mlx/backend/metal/conv.cpp
Normal file
555
mlx/backend/metal/conv.cpp
Normal file
@ -0,0 +1,555 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/copy.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/matmul.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void explicit_gemm_conv_1D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const MLXConvParams<1>& conv_params) {
|
||||||
|
// Pad input
|
||||||
|
std::vector<int> padded_shape = {
|
||||||
|
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
|
||||||
|
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||||
|
|
||||||
|
// Fill with zeros
|
||||||
|
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||||
|
|
||||||
|
// Pick input slice from padded
|
||||||
|
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
|
||||||
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
|
in_padded_slice.copy_shared_buffer(
|
||||||
|
in_padded,
|
||||||
|
in_padded.strides(),
|
||||||
|
in_padded.flags(),
|
||||||
|
in_padded_slice.size(),
|
||||||
|
data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||||
|
|
||||||
|
// Make strided view
|
||||||
|
std::vector<int> strided_shape = {
|
||||||
|
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
|
||||||
|
|
||||||
|
std::vector<size_t> strided_strides = {
|
||||||
|
in_padded.strides()[0],
|
||||||
|
in_padded.strides()[1] * conv_params.str[0],
|
||||||
|
in_padded.strides()[1],
|
||||||
|
in_padded.strides()[2]};
|
||||||
|
auto flags = in_padded.flags();
|
||||||
|
|
||||||
|
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||||
|
in_strided_view.copy_shared_buffer(
|
||||||
|
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||||
|
|
||||||
|
// Materialize strided view
|
||||||
|
std::vector<int> strided_reshape = {
|
||||||
|
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
|
||||||
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
|
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||||
|
|
||||||
|
// Peform gemm
|
||||||
|
std::vector<array> copies = {in_padded, in_strided};
|
||||||
|
mlx_matmul(
|
||||||
|
s,
|
||||||
|
d,
|
||||||
|
/*a = */ in_strided,
|
||||||
|
/*b = */ wt,
|
||||||
|
/*c = */ out,
|
||||||
|
/*M = */ strided_reshape[0],
|
||||||
|
/*N = */ conv_params.O,
|
||||||
|
/*K = */ strided_reshape[1],
|
||||||
|
/*batch_size_out = */ 1,
|
||||||
|
/*a_cols = */ strided_reshape[1],
|
||||||
|
/*b_cols = */ strided_reshape[1],
|
||||||
|
/*a_transposed = */ false,
|
||||||
|
/*b_transposed = */ true,
|
||||||
|
/*copies = */ copies);
|
||||||
|
}
|
||||||
|
|
||||||
|
void conv_1D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation) {
|
||||||
|
// Make conv params
|
||||||
|
MLXConvParams<1> conv_params{
|
||||||
|
/* const int N = */ in.shape(0),
|
||||||
|
/* const int C = */ in.shape(2),
|
||||||
|
/* const int O = */ wt.shape(0),
|
||||||
|
/* const int iS[NDIM] = */ {in.shape(1)},
|
||||||
|
/* const int wS[NDIM] = */ {wt.shape(1)},
|
||||||
|
/* const int oS[NDIM] = */ {out.shape(1)},
|
||||||
|
/* const int str[NDIM] = */ {wt_strides[0]},
|
||||||
|
/* const int pad[NDIM] = */ {padding[0]},
|
||||||
|
/* const int dil[NDIM] = */ {wt_dilation[0]},
|
||||||
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
|
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
||||||
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
|
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
||||||
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
|
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Direct to explicit gemm conv
|
||||||
|
if (wt_dilation[0] == 1) {
|
||||||
|
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to fallback conv
|
||||||
|
else {
|
||||||
|
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void slow_conv_2D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const MLXConvParams<2>& conv_params) {
|
||||||
|
int bm = 16, bn = 8;
|
||||||
|
int tm = 4, tn = 4;
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "naive_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" << bn
|
||||||
|
<< "_tm" << tm << "_tn" << tn;
|
||||||
|
|
||||||
|
// Encode and dispatch kernel
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
|
||||||
|
|
||||||
|
size_t grid_dim_x = (n_pixels + (tm * bm) - 1) / (tm * bm);
|
||||||
|
size_t grid_dim_y = (conv_params.O + (tn * bn) - 1) / (tn * bn);
|
||||||
|
size_t grid_dim_z = conv_params.N;
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(bm, bn, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, wt, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void implicit_gemm_conv_2D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const MLXConvParams<2>& conv_params) {
|
||||||
|
int bm = 32, bn = 32, bk = 16;
|
||||||
|
int wm = 2, wn = 2;
|
||||||
|
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
|
||||||
|
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
|
||||||
|
|
||||||
|
// Encode and dispatch kernel
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||||
|
int implicit_N = conv_params.O;
|
||||||
|
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
||||||
|
|
||||||
|
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
|
||||||
|
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
|
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, wt, 1);
|
||||||
|
set_array_buffer(compute_encoder, out, 2);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void explicit_gemm_conv_2D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const MLXConvParams<2>& conv_params) {
|
||||||
|
// Pad input
|
||||||
|
std::vector<int> padded_shape = {
|
||||||
|
conv_params.N,
|
||||||
|
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||||
|
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||||
|
conv_params.C};
|
||||||
|
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||||
|
|
||||||
|
// Fill with zeros
|
||||||
|
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||||
|
|
||||||
|
// Pick input slice from padded
|
||||||
|
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
||||||
|
conv_params.pad[1] * in_padded.strides()[2];
|
||||||
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
|
in_padded_slice.copy_shared_buffer(
|
||||||
|
in_padded,
|
||||||
|
in_padded.strides(),
|
||||||
|
in_padded.flags(),
|
||||||
|
in_padded_slice.size(),
|
||||||
|
data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||||
|
|
||||||
|
// Make strided view
|
||||||
|
std::vector<int> strided_shape = {
|
||||||
|
conv_params.N,
|
||||||
|
conv_params.oS[0],
|
||||||
|
conv_params.oS[1],
|
||||||
|
conv_params.wS[0],
|
||||||
|
conv_params.wS[1],
|
||||||
|
conv_params.C};
|
||||||
|
|
||||||
|
std::vector<size_t> strided_strides = {
|
||||||
|
in_padded.strides()[0],
|
||||||
|
in_padded.strides()[1] * conv_params.str[0],
|
||||||
|
in_padded.strides()[2] * conv_params.str[1],
|
||||||
|
in_padded.strides()[1],
|
||||||
|
in_padded.strides()[2],
|
||||||
|
in_padded.strides()[3]};
|
||||||
|
auto flags = in_padded.flags();
|
||||||
|
|
||||||
|
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||||
|
in_strided_view.copy_shared_buffer(
|
||||||
|
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||||
|
|
||||||
|
// Materialize strided view
|
||||||
|
std::vector<int> strided_reshape = {
|
||||||
|
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
|
||||||
|
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
|
||||||
|
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||||
|
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||||
|
|
||||||
|
// Peform gemm
|
||||||
|
std::vector<array> copies = {in_padded, in_strided};
|
||||||
|
mlx_matmul(
|
||||||
|
s,
|
||||||
|
d,
|
||||||
|
/*a = */ in_strided,
|
||||||
|
/*b = */ wt,
|
||||||
|
/*c = */ out,
|
||||||
|
/*M = */ strided_reshape[0],
|
||||||
|
/*N = */ conv_params.O,
|
||||||
|
/*K = */ strided_reshape[1],
|
||||||
|
/*batch_size_out = */ 1,
|
||||||
|
/*a_cols = */ strided_reshape[1],
|
||||||
|
/*b_cols = */ strided_reshape[1],
|
||||||
|
/*a_transposed = */ false,
|
||||||
|
/*b_transposed = */ true,
|
||||||
|
/*copies = */ copies);
|
||||||
|
}
|
||||||
|
|
||||||
|
void winograd_conv_2D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const MLXConvParams<2>& conv_params,
|
||||||
|
std::vector<array>& copies_w) {
|
||||||
|
std::vector<int> padded_shape = {
|
||||||
|
conv_params.N,
|
||||||
|
conv_params.iS[0] + 2 * conv_params.pad[0],
|
||||||
|
conv_params.iS[1] + 2 * conv_params.pad[1],
|
||||||
|
conv_params.C};
|
||||||
|
|
||||||
|
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
|
||||||
|
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
|
||||||
|
|
||||||
|
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||||
|
|
||||||
|
// Fill with zeros
|
||||||
|
array zero_arr = array(0, in.dtype());
|
||||||
|
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
|
||||||
|
|
||||||
|
// Pick input slice from padded
|
||||||
|
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
||||||
|
conv_params.pad[1] * in_padded.strides()[2];
|
||||||
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
|
in_padded_slice.copy_shared_buffer(
|
||||||
|
in_padded,
|
||||||
|
in_padded.strides(),
|
||||||
|
in_padded.flags(),
|
||||||
|
in_padded_slice.size(),
|
||||||
|
data_offset);
|
||||||
|
|
||||||
|
// Copy input values into the slice
|
||||||
|
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||||
|
|
||||||
|
copies_w.push_back(in_padded_slice);
|
||||||
|
copies_w.push_back(in_padded);
|
||||||
|
copies_w.push_back(zero_arr);
|
||||||
|
|
||||||
|
MLXConvParams<2> conv_params_updated{
|
||||||
|
/* const int N = */ in_padded.shape(0),
|
||||||
|
/* const int C = */ in_padded.shape(3),
|
||||||
|
/* const int O = */ wt.shape(0),
|
||||||
|
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
|
||||||
|
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
|
||||||
|
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
|
||||||
|
/* const int str[NDIM] = */ {1, 1},
|
||||||
|
/* const int pad[NDIM] = */ {0, 0},
|
||||||
|
/* const int dil[NDIM] = */ {1, 1},
|
||||||
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
|
{in_padded.strides()[0],
|
||||||
|
in_padded.strides()[1],
|
||||||
|
in_padded.strides()[2],
|
||||||
|
in_padded.strides()[3]},
|
||||||
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
|
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||||
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
|
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||||
|
};
|
||||||
|
|
||||||
|
int O_c = conv_params.O;
|
||||||
|
int C_c = conv_params.C;
|
||||||
|
|
||||||
|
int N_tiles_n = conv_params.N;
|
||||||
|
int N_tiles_h = (conv_params.oS[0] + 5) / 6;
|
||||||
|
int N_tiles_w = (conv_params.oS[1] + 5) / 6;
|
||||||
|
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
|
||||||
|
|
||||||
|
// Do filter transform
|
||||||
|
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
||||||
|
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
|
||||||
|
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
|
||||||
|
copies_w.push_back(filt_wg);
|
||||||
|
{
|
||||||
|
int bc = 32;
|
||||||
|
int bo = 4;
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||||
|
<< bc;
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, wt, 0);
|
||||||
|
set_array_buffer(compute_encoder, filt_wg, 1);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&C_c, sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(&O_c, sizeof(int), 3);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do input transform
|
||||||
|
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
||||||
|
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
|
||||||
|
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
|
||||||
|
copies_w.push_back(inp_wg);
|
||||||
|
{
|
||||||
|
int bc = 32;
|
||||||
|
int wm = 2;
|
||||||
|
int wn = 2;
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||||
|
<< bc;
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, in_padded, 0);
|
||||||
|
set_array_buffer(compute_encoder, inp_wg, 1);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
|
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do batched gemm
|
||||||
|
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
||||||
|
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
|
||||||
|
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
|
||||||
|
copies_w.push_back(out_wg);
|
||||||
|
{
|
||||||
|
std::vector<array> empty_copies;
|
||||||
|
mlx_matmul(
|
||||||
|
s,
|
||||||
|
d,
|
||||||
|
/*a = */ inp_wg,
|
||||||
|
/*b = */ filt_wg,
|
||||||
|
/*c = */ out_wg,
|
||||||
|
/*M = */ N_tiles,
|
||||||
|
/*N = */ conv_params.O,
|
||||||
|
/*K = */ conv_params.C,
|
||||||
|
/*batch_size_out = */ 8 * 8,
|
||||||
|
/*a_cols = */ conv_params.C,
|
||||||
|
/*b_cols = */ conv_params.O,
|
||||||
|
/*a_transposed = */ false,
|
||||||
|
/*b_transposed = */ false,
|
||||||
|
/*copies = */ empty_copies);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do output transform
|
||||||
|
{
|
||||||
|
int bc = 32;
|
||||||
|
int wm = 2;
|
||||||
|
int wn = 2;
|
||||||
|
std::ostringstream kname;
|
||||||
|
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||||
|
<< bc;
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
set_array_buffer(compute_encoder, out_wg, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
|
||||||
|
compute_encoder->setBytes(
|
||||||
|
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
|
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void conv_2D_gpu(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& in,
|
||||||
|
const array& wt,
|
||||||
|
array out,
|
||||||
|
const std::vector<int>& padding,
|
||||||
|
const std::vector<int>& wt_strides,
|
||||||
|
const std::vector<int>& wt_dilation,
|
||||||
|
std::vector<array>& copies) {
|
||||||
|
// Make conv params
|
||||||
|
MLXConvParams<2> conv_params{
|
||||||
|
/* const int N = */ in.shape(0),
|
||||||
|
/* const int C = */ in.shape(3),
|
||||||
|
/* const int O = */ wt.shape(0),
|
||||||
|
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
|
||||||
|
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
|
||||||
|
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
|
||||||
|
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
|
||||||
|
/* const int pad[NDIM] = */ {padding[0], padding[1]},
|
||||||
|
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||||
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
|
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
|
||||||
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
|
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||||
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
|
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Direct to winograd conv
|
||||||
|
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
|
||||||
|
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
|
||||||
|
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
|
||||||
|
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
|
||||||
|
conv_params.dil[1] == 1) {
|
||||||
|
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to implicit gemm conv
|
||||||
|
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
|
||||||
|
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to explicit gemm conv
|
||||||
|
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
|
||||||
|
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct to fallback conv
|
||||||
|
else {
|
||||||
|
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
// Ensure contiguity
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto in = inputs[0];
|
||||||
|
auto wt = inputs[1];
|
||||||
|
if (!in.flags().row_contiguous) {
|
||||||
|
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||||
|
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||||
|
copies.push_back(arr_copy);
|
||||||
|
in = arr_copy;
|
||||||
|
}
|
||||||
|
if (!wt.flags().row_contiguous) {
|
||||||
|
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
|
||||||
|
copy_gpu(wt, arr_copy, CopyType::General, s);
|
||||||
|
copies.push_back(arr_copy);
|
||||||
|
wt = arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2D conv
|
||||||
|
if (out.ndim() == 4) {
|
||||||
|
conv_2D_gpu(
|
||||||
|
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
|
||||||
|
}
|
||||||
|
// 1D conv
|
||||||
|
else if (out.ndim() == 3) {
|
||||||
|
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||||
|
}
|
||||||
|
// Throw error
|
||||||
|
else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[Convolution::eval_gpu] Only supports 1D or 2D convolutions.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear copies
|
||||||
|
if (copies.size() > 0) {
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
16
mlx/backend/metal/copy.h
Normal file
16
mlx/backend/metal/copy.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Copy `src` into `out` on the GPU using copy strategy `ctype`, on stream `s`.
// NOTE(review): callers in conv.cpp pass an `out` constructed without data, so
// this presumably allocates `out` — confirm against the implementation.
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
||||||
|
// Overload of copy_gpu without an explicit stream argument.
// NOTE(review): which stream is used is not visible from this header — confirm.
void copy_gpu(const array& src, array& out, CopyType ctype);
|
||||||
|
void copy_gpu_inplace(
|
||||||
|
const array& src,
|
||||||
|
array& out,
|
||||||
|
CopyType ctype,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
30
mlx/backend/metal/kernels/arange.metal
Normal file
30
mlx/backend/metal/kernels/arange.metal
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
// Writes out[index] = start + index * step, one element per thread.
template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  const auto offset = index * step;
  out[index] = start + offset;
}
|
||||||
|
|
||||||
|
// Explicitly instantiates the arange kernel for `type` and exposes it to the
// host under the name "arange" followed by `tname`.
#define instantiate_arange(tname, type)                                 \
  template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
      constant const type& start,                                       \
      constant const type& step,                                        \
      device type* out,                                                 \
      uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Kernel instantiations: each line below expands instantiate_arange for one
// element type.
instantiate_arange(uint8, uint8_t)
|
||||||
|
instantiate_arange(uint16, uint16_t)
|
||||||
|
instantiate_arange(uint32, uint32_t)
|
||||||
|
instantiate_arange(uint64, uint64_t)
|
||||||
|
instantiate_arange(int8, int8_t)
|
||||||
|
instantiate_arange(int16, int16_t)
|
||||||
|
instantiate_arange(int32, int32_t)
|
||||||
|
instantiate_arange(int64, int64_t)
|
||||||
|
instantiate_arange(float16, half)
|
||||||
|
instantiate_arange(float32, float)
|
||||||
|
instantiate_arange(bfloat16, bfloat16_t)
|
208
mlx/backend/metal/kernels/arg_reduce.metal
Normal file
208
mlx/backend/metal/kernels/arg_reduce.metal
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
#include <metal_atomic>
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// A value together with the index it was read from.
template <typename U>
struct IndexValPair {
  uint32_t index;
  U val;

  IndexValPair(uint32_t idx, U value) : index(idx), val(value) {}
};
|
||||||
|
|
||||||
|
// Arg-min reduction operator over (index, value) pairs.
template <typename U>
struct ArgMin {
  // Starting value of the reduction
  static constexpr constant U init = Limits<U>::max;

  // Combine two candidates: the smaller value wins; on equal values the
  // smaller index wins.
  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    if (best.val > current.val) {
      return current;
    }
    if (best.val == current.val && best.index > current.index) {
      return current;
    }
    return best;
  }

  // Fold N consecutive values into `best`; vals[i] has index offset + i.
  // Only a strictly smaller value replaces the current best.
  template <int N>
  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
    for (int i = 0; i < N; ++i) {
      const U candidate = vals[i];
      if (candidate < best.val) {
        best.val = candidate;
        best.index = offset + i;
      }
    }
    return best;
  }
};
|
||||||
|
|
||||||
|
// Arg-max reduction operator over (index, value) pairs.
template <typename U>
struct ArgMax {
  // Starting value of the reduction
  static constexpr constant U init = Limits<U>::min;

  // Combine two candidates: the larger value wins; on equal values the
  // smaller index wins.
  IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
    if (best.val < current.val) {
      return current;
    }
    if (best.val == current.val && best.index > current.index) {
      return current;
    }
    return best;
  }

  // Fold N consecutive values into `best`; vals[i] has index offset + i.
  // Only a strictly larger value replaces the current best.
  template <int N>
  IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
    for (int i = 0; i < N; ++i) {
      const U candidate = vals[i];
      if (candidate > best.val) {
        best.val = candidate;
        best.index = offset + i;
      }
    }
    return best;
  }
};
|
||||||
|
|
||||||
|
// bool overload: the value is shuffled as a uint32_t and converted back.
bool simd_shuffle_down(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_down(widened, delta);
}
|
||||||
|
|
||||||
|
// uint64_t overload: the value is reinterpreted as a uint2, shuffled, and
// reinterpreted back.
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(simd_shuffle_down(halves, delta));
}
|
||||||
|
|
||||||
|
// int64_t overload: the value is reinterpreted as a uint2, shuffled, and
// reinterpreted back.
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(simd_shuffle_down(halves, delta));
}
|
||||||
|
|
||||||
|
// IndexValPair overload: shuffles the index and the value separately and
// rebuilds the pair.
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
  const uint32_t shifted_index = simd_shuffle_down(data.index, delta);
  const U shifted_val = simd_shuffle_down(data.val, delta);
  return IndexValPair<U>(shifted_index, shifted_val);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Arg-reduction kernel: for every output element, scans the reduction axis of
// `in` and writes the index selected by `Op` to `out`.
//
// Template parameters:
//   T        element type of the input
//   Op       reduction functor; must provide `init`, `reduce` and
//            `reduce_many<N>` (see their uses below)
//   N_READS  number of consecutive axis elements each thread reads per step
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
    const device T *in [[buffer(0)]],
    device uint32_t *out [[buffer(1)]],
    const device int *shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    const device size_t& ndim [[buffer(5)]],
    const device size_t& axis_stride [[buffer(6)]],
    const device size_t& axis_size [[buffer(7)]],
    threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

  // Shapes and strides *do not* contain the reduction axis. The reduction size
  // and stride are provided in axis_size and axis_stride.
  //
  // Note: in shape == out shape with this convention.
  //
  // The sketch of the kernel is as follows.
  //    1. Launch prod(shape) * thread_group_size threads.
  //    2. Loop ceildiv(axis_size, N_READS * lsize) times
  //    3. Read N_READS input values per thread
  //    4. Reduce among them and go to 3
  //    5. Reduce in each simd_group
  //    6. Write the per-simd_group result to threadgroup memory
  //    7. Reduce those results across the thread group
  //    8. Write the output without need for atomic
  Op op;

  // Compute the input/output index. There is one beginning and one output for
  // the whole threadgroup (every thread of a group shares gid / lsize).
  auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
  auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);

  IndexValPair<T> best(0, Op::init);

  // Loop over the reduction axis in lsize*N_READS buckets. The bucket count
  // is loop-invariant, so it is computed once instead of on every iteration.
  const auto num_buckets = ceildiv(axis_size, N_READS*lsize);
  for (uint r=0; r < num_buckets; r++) {
    // First axis position handled by this thread in the current bucket
    uint32_t current_index = r*lsize*N_READS + lid*N_READS;
    uint32_t offset = current_index;
    const device T * current_in = in + in_idx + current_index * axis_stride;

    // Read N_READS values; positions past the end of the axis are padded with
    // Op::init so they cannot win the reduction.
    T vals[N_READS];
    for (int i=0; i<N_READS; i++) {
      vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
      current_index++;
      current_in += axis_stride;
    }
    best = op.template reduce_many<N_READS>(best, vals, offset);
  }

  // At this point every thread holds its own best value, so we need to reduce
  // across the thread group.

  // First per simd reduction.
  for (uint offset=simd_size/2; offset>0; offset/=2) {
    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Write to the threadgroup memory (one slot per simd group)
  if (simd_lane_id == 0) {
    local_data[simd_group_id] = best;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Only the first simd group performs the final reduction. No barrier
  // follows, so the other simd groups can exit here.
  if (simd_group_id != 0) {
    return;
  }

  // Read the appropriate value from local data and perform one simd reduction
  uint simd_groups = ceildiv(lsize, simd_size);
  if (simd_lane_id < simd_groups) {
    best = local_data[simd_lane_id];
  }
  for (uint offset=simd_size/2; offset>0; offset/=2) {
    IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
    best = op.reduce(best, neighbor);
  }

  // Finally write the output
  if (lid == 0) {
    out[out_idx] = best.index;
  }
}
|
||||||
|
|
||||||
|
// Explicitly instantiates arg_reduce_general for element type `itype` and
// reduction `op<itype>` with N_READS fixed to 4, exported under host name
// `name`.
#define instantiate_arg_reduce_helper(name, itype, op) \
  template [[host_name(name)]] \
  [[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
      const device itype *in [[buffer(0)]], \
      device uint32_t * out [[buffer(1)]], \
      const device int *shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
      const device size_t *out_strides [[buffer(4)]], \
      const device size_t& ndim [[buffer(5)]], \
      const device size_t& axis_stride [[buffer(6)]], \
      const device size_t& axis_size [[buffer(7)]], \
      threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_size [[threads_per_simdgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
// Emits both the "argmin_<name>" and the "argmax_<name>" kernel for `itype`.
#define instantiate_arg_reduce(name, itype) \
  instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
  instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)

// Kernels for every supported element type.
instantiate_arg_reduce(bool_, bool)
instantiate_arg_reduce(uint8, uint8_t)
instantiate_arg_reduce(uint16, uint16_t)
instantiate_arg_reduce(uint32, uint32_t)
instantiate_arg_reduce(uint64, uint64_t)
instantiate_arg_reduce(int8, int8_t)
instantiate_arg_reduce(int16, int16_t)
instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)
|
392
mlx/backend/metal/kernels/bf16_math.h
Normal file
392
mlx/backend/metal/kernels/bf16_math.h
Normal file
@ -0,0 +1,392 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Metal math for bfloat16
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
Following the Metal Shading Language Specification (Metal 3.1)
|
||||||
|
|
||||||
|
"bfloat is an extended floating point type that only allows implicit conversion
|
||||||
|
to a type of greater floating point rank. While bfloat can be implicitly
|
||||||
|
converted to float, it cannot be implicitly converted to half, and neither
|
||||||
|
float nor half can be implicitly converted to bfloat."
|
||||||
|
|
||||||
|
Further, as far as I can tell, the stdlib math/simd functions are not defined
|
||||||
|
for bfloat and calling with an argument of type bfloat will result in that
|
||||||
|
argument getting implicitly converted to float which then returns an output
|
||||||
|
that is (likely) a float which cannot be implicitly converted into a bfloat
|
||||||
|
|
||||||
|
This leads to situations where
|
||||||
|
bfloat a = 5.0bf;
|
||||||
|
bfloat b = metal::abs(a); // this will throw an error since abs returns float
|
||||||
|
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
||||||
|
|
||||||
|
For the moment, I will be adding overloaded instantiations of the math
|
||||||
|
functions to accordingly automatically handle the casting
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Defines overloads of the Metal math functions for `itype`. Each overload
// casts its arguments to the compute type `ctype`, forwards to the matching
// __metal_* builtin (passing the `mfast` math-mode argument where the builtin
// call takes one), and casts the result to `otype`.
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
  \
  METAL_FUNC otype abs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acos(itype x) { \
    return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype acosh(itype x) { \
    return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asin(itype x) { \
    return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype asinh(itype x) { \
    return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atan(itype y_over_x) { \
    return static_cast<otype>( \
        __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
  } \
  METAL_FUNC otype atan2(itype y, itype x) { \
    return static_cast<otype>( \
        __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype atanh(itype x) { \
    return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype ceil(itype x) { \
    return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cos(itype x) { \
    return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cosh(itype x) { \
    return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype cospi(itype x) { \
    return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype divide(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype exp(itype x) { \
    return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp10(itype x) { \
    return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype exp2(itype x) { \
    return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fabs(itype x) { \
    return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fdim(itype x, itype y) { \
    ctype t = static_cast<ctype>(x - y); \
    return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
  } \
  METAL_FUNC otype floor(itype x) { \
    return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype fma(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fma( \
        static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
  } \
  METAL_FUNC otype fmax(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmin(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype fmod(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype fract(itype x) { \
    return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype frexp(itype x, thread int& exp) { \
    return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
  } \
  METAL_FUNC otype ldexp(itype x, int k) { \
    return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
  } \
  METAL_FUNC otype log(itype x) { \
    return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log10(itype x) { \
    return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype log2(itype x) { \
    return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype max(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype max3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmax3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype median3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmedian3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype min(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype min3(itype x, itype y, itype z) { \
    return static_cast<otype>(__metal_fmin3( \
        static_cast<ctype>(x), \
        static_cast<ctype>(y), \
        static_cast<ctype>(z), \
        mfast)); \
  } \
  METAL_FUNC otype nextafter(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
  } \
  METAL_FUNC otype pow(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype powr(itype x, itype y) { \
    return static_cast<otype>( \
        __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
  } \
  METAL_FUNC otype rint(itype x) { \
    return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype round(itype x) { \
    return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype rsqrt(itype x) { \
    return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sin(itype x) { \
    return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinh(itype x) { \
    return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sinpi(itype x) { \
    return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype sqrt(itype x) { \
    return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tan(itype x) { \
    return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanh(itype x) { \
    return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype tanpi(itype x) { \
    return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
  } \
  METAL_FUNC otype trunc(itype x) { \
    return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
  }
|
||||||
|
|
||||||
|
// bfloat16_t math overloads, computed in float, for the three math-mode
// namespaces: metal, metal::fast and metal::precise.
namespace metal {

instantiate_metal_math_funcs(
    bfloat16_t, bfloat16_t, float, __METAL_MAYBE_FAST_MATH__);

namespace fast {

instantiate_metal_math_funcs(
    bfloat16_t, bfloat16_t, float, __METAL_FAST_MATH__);

} // namespace fast

namespace precise {

instantiate_metal_math_funcs(
    bfloat16_t, bfloat16_t, float, __METAL_PRECISE_MATH__);

} // namespace precise

} // namespace metal
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Metal simd for bfloat16
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Defines SIMD-group data-movement overloads (broadcast / shuffle variants)
// for `itype`. Each overload converts its data arguments with
// `itype_to_ctype`, calls the matching __metal_simd_* builtin, and converts
// the result back with `ctype_to_otype`.
#define instantiate_metal_simd_comm_funcs(itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
  \
  METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
    return ctype_to_otype( \
        __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_down( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta, ushort modulo) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_and_fill_up( \
      itype data, itype filling_data, ushort delta) { \
    return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
        itype_to_ctype(data), \
        itype_to_ctype(filling_data), \
        delta, \
        __metal_get_simdgroup_size(ushort()))); \
  } \
  \
  METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
  } \
  \
  METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
    return ctype_to_otype( \
        __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
  }
|
||||||
|
|
||||||
|
// Defines SIMD-group reduction and prefix-scan overloads for `itype`. Each
// overload casts its argument to `ctype`, calls the matching __metal_simd_*
// builtin, and casts the result to `otype`.
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
  \
  METAL_FUNC otype simd_max(itype data) { \
    return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_min(itype data) { \
    return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
    return static_cast<otype>( \
        __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_product(itype data) { \
    return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_sum(itype data) { \
    return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
  } \
  \
  METAL_FUNC otype simd_xor(itype data) { \
    return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
  }
|
||||||
|
|
||||||
|
// Bit-level conversions between bfloat16_t and uint16_t, used to move
// bfloat16 values through the integer SIMD shuffle builtins.
#if defined(__HAVE_BFLOAT__)

// With __HAVE_BFLOAT__ defined: reinterpret the 16 bits directly.
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

#else

// Otherwise go through _MLX_BFloat16: read its raw bits, or rebuild it from
// raw bits. The argument is parenthesized so that an expression argument
// (e.g. `a + b`) has the member access applied to the whole expression.
#define bfloat16_to_uint16(x) (x).bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())

#endif
|
||||||
|
|
||||||
|
// bfloat16_t SIMD-group overloads: data movement goes through the uint16_t
// bit pattern, reductions are computed in float.
namespace metal {

instantiate_metal_simd_comm_funcs(
    bfloat16_t, bfloat16_t, uint16_t, bfloat16_to_uint16, uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);

} // namespace metal
|
553
mlx/backend/metal/kernels/conv.metal
Normal file
553
mlx/backend/metal/kernels/conv.metal
Normal file
@ -0,0 +1,553 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/gemm/conv.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
/// Slow and naive kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Direct (non-GEMM) 2D convolution. Each thread accumulates a TM x TN tile of
// outputs: TM consecutive flattened output positions by TN consecutive output
// channels, looping over the filter window and the input channels.
//
// `tid.z` offsets `in`/`out` by their stride 0, `tid.x`/`lid.x` select the
// flattened output position and `tid.y`/`lid.y` select the output channel.
// BC, simd_gid and simd_lid are unused by this kernel.
template <typename T,
          const int BM, /* Threadgroup rows (in threads) */
          const int BN, /* Threadgroup cols (in threads) */
          const int TM, /* Thread rows (in elements) */
          const int TN, /* Thread cols (in elements) */
          const int BC = 16>
[[kernel]] void naive_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  (void)simd_gid;
  (void)simd_lid;

  out += tid.z * params.out_strides[0];
  in += tid.z * params.in_strides[0];

  // First output channel and first flattened output position of this thread
  int out_o = tid.y * BN * TN + lid.y * TN;
  int out_hw = tid.x * BM * TM + lid.x * TM;

  // Both arrays are indexed by the row index m in [0, TM), so both must hold
  // TM entries (out_w was previously sized TN, which overflows when TM > TN).
  int out_h[TM];
  int out_w[TM];

  // Split each flattened output position into (row, column) using the output
  // width params.oS[1].
  for(int m = 0; m < TM; ++m) {
    int mm = (out_hw + m);
    out_h[m] = mm / params.oS[1];
    out_w[m] = mm % params.oS[1];
  }

  T in_local[TM];
  T wt_local[TN];
  T out_local[TM * TN] = {T(0)};

  for(int h = 0; h < params.wS[0]; ++h) {
    for(int w = 0; w < params.wS[1]; ++w) {
      for(int c = 0; c < params.C; ++c) {

        // Local in: positions falling outside the input contribute zero
        for(int m = 0; m < TM; m++) {
          int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
          int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];

          bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
          in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
        }

        // Load weight: output channels past params.O contribute zero
        for (int n = 0; n < TN; ++n) {
          int o = out_o + n;
          wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
                                          h * params.wt_strides[1] +
                                          w * params.wt_strides[2] + c] : T(0);
        }

        // Accumulate the outer product of the loaded inputs and weights
        for(int m = 0; m < TM; ++m) {
          for(int n = 0; n < TN; ++n) {
            out_local[m * TN + n] += in_local[m] * wt_local[n];
          }
        }

      }
    }
  }

  // Store only the results that fall inside the output bounds
  for(int m = 0; m < TM; ++m) {
    for(int n = 0; n < TN; ++n) {
      if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
        out[out_h[m] * params.out_strides[1] +
            out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
    }
  }

}
|
||||||
|
|
||||||
|
// Instantiations
|
||||||
|
|
||||||
|
// Explicitly instantiates naive_conv_2d for `itype` with the given block
// (bm, bn) and per-thread tile (tm, tn) sizes; the host name encodes all of
// them.
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
  [[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// The two tile configurations built for every element type.
#define instantiate_naive_conv_2d_blocks(name, itype) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)

instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
/// Implicit gemm kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Kernel entry point for the implicit-GEMM 2D convolution. It declares the
// threadgroup scratch buffer required by Conv2DImplicitGEMMKernel and hands
// every argument over to that kernel's static run().
template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  using conv_gemm_t = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;

  // Scratch space sized by the GEMM kernel itself
  threadgroup T scratch[conv_gemm_t::tgp_mem_size];

  conv_gemm_t::run(in, wt, out, params, scratch, tid, lid, simd_gid, simd_lid);

}
|
||||||
|
|
||||||
|
// Explicitly instantiates implicit_gemm_conv_2d for `itype` with block sizes
// (bm, bn, bk) and warp-grid sizes (wm, wn); the host name encodes all of
// them.
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// The three block configurations built for every element type.
#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
  instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
/// Winograd kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Primary template: declares no members. The tile sizes and transform
// matrices are supplied by explicit specializations such as <6, 3, 8>.
template <int M, int R, int S>
struct WinogradTransforms {};
|
||||||
|
|
||||||
|
// Transform constants for output tile size 6, filter size 3 and 8x8 SIMD
// matrices. Every table is declared 8x8; out_transform and wt_transform list
// fewer than 8 entries per row, and the unlisted entries are zero-initialized.
template <>
struct WinogradTransforms<6, 3, 8> {
  MLX_MTL_CONST int OUT_TILE_SIZE = 6;
  MLX_MTL_CONST int FILTER_SIZE = 3;
  MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
  MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;

  // Input-tile transform (8 x 8)
  MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
    { 1.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f},
    { 0.00f,  1.00f, -1.00f,  0.50f, -0.50f,  2.00f, -2.00f, -1.00f},
    {-5.25f,  1.00f,  1.00f,  0.25f,  0.25f,  4.00f,  4.00f,  0.00f},
    { 0.00f, -4.25f,  4.25f, -2.50f,  2.50f, -2.50f,  2.50f,  5.25f},
    { 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f,  0.00f},
    { 0.00f,  1.00f, -1.00f,  2.00f, -2.00f,  0.50f, -0.50f, -5.25f},
    {-1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  0.00f},
    { 0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  1.00f},
  };

  // Output-tile transform (8 x 6 listed)
  MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
    { 1.00f,  0.00f,  0.00f,  0.00f,  0.00f,  0.00f},
    { 1.00f,  1.00f,  1.00f,  1.00f,  1.00f,  1.00f},
    { 1.00f, -1.00f,  1.00f, -1.00f,  1.00f, -1.00f},
    { 1.00f,  2.00f,  4.00f,  8.00f,  16.00f,  32.00f},
    { 1.00f, -2.00f,  4.00f, -8.00f,  16.00f, -32.00f},
    { 1.00f,  0.50f,  0.25f,  0.125f,  0.0625f,  0.03125f},
    { 1.00f, -0.50f,  0.25f, -0.125f,  0.0625f, -0.03125f},
    { 0.00f,  0.00f,  0.00f,  0.00f,  0.00f,  1.00f},
  };

  // Filter (weight) transform (8 x 3 listed)
  MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
    { 1.00, 0.00, 0.00},
    { -2.0/9.00, -2.0/9.00, -2.0/9.00},
    { -2.0/9.00, 2.0/9.00, -2.0/9.00},
    { 1.0/90.0, 1.0/45.0, 2.0/45.0},
    { 1.0/90.0, -1.0/45.0, 2.0/45.0},
    { 32.0/45.0, 16.0/45.0, 8.0/45.0},
    { 32.0/45.0, -16.0/45.0, 8.0/45.0},
    { 0.00, 0.00, 1.00},
  };
};
|
||||||
|
|
||||||
|
// Out-of-class definitions for the static transform tables of
// WinogradTransforms<6, 3, 8>.
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
|
||||||
|
|
||||||
|
/**
 * Winograd weight transform kernel.
 *
 * Each simdgroup handles one output filter. It stages BC channels of the
 * R x R filter in threadgroup memory, then for every staged channel computes
 * (G * g) * Gt with 8x8 simdgroup matrices and scatters the two elements
 * owned by each lane into wt_out, which is stored transposed as
 * (A x A x C x O) with the tile position encoded as `row * 8 + col`.
 *
 * Template parameters:
 *   T  - element type of the weights
 *   BC - number of input channels staged per loop iteration
 *   BO - number of output filters (one per simdgroup) per threadgroup
 *   M  - Winograd output tile size
 *   R  - filter size
 *
 * Arguments:
 *   wt_in  - input weights, read as [filter][kh][kw][channel]
 *   wt_out - transformed weights
 *   C, O   - input / output channel counts
 *
 * NOTE(review): the filter index and a partial final channel block are not
 * bounds-checked; presumably the host only dispatches this kernel when O is a
 * multiple of BO and C is a multiple of BC -- confirm against the caller.
 */
template <typename T,
          int BC = 32,
          int BO = 4,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
    const device T* wt_in [[buffer(0)]],
    device T* wt_out [[buffer(1)]],
    const constant int& C [[buffer(2)]],
    const constant int& O [[buffer(3)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {

  using WGT = WinogradTransforms<M, R, 8>;

  // Get lane position in simdgroup
  const short qid = simd_lane_id / 4;
  const short sm = (qid & 4) + (simd_lane_id / 2) % 4;
  const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

  // Initialize G matrix
  simdgroup_matrix<T, 8, 8> G;
  G.thread_elements()[0] = WGT::wt_transform[sm][sn];
  G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];

  // Initialize Gt matrix
  simdgroup_matrix<T, 8, 8> Gt;
  Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
  Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];

  // Move to the correct output filter
  size_t ko = BO * tid + simd_group_id;
  wt_in += ko * R * R * C;

  // wt_out is stored transposed (A x A x C x O).
  // The tile offsets are kept in size_t so that `ohw * C * O` is evaluated in
  // 64 bits; with short/int operands the product can overflow for large C * O.
  const size_t ohw_0 = sm * 8 + sn;
  const size_t ohw_1 = sm * 8 + sn + 1;
  device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
  device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;

  // Prepare shared memory
  threadgroup T Ws[BO][R][R][BC];

  // Loop over C
  for (int bc = 0; bc < C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read into shared memory
    for (int kh = 0; kh < R; ++kh) {
      for (int kw = 0; kw < R; ++kw) {
        for (int kc = simd_lane_id; kc < BC; kc += 32) {
          Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do transform and store the result
    for (int c = 0; c < BC; ++c) {
      // Lanes outside the R x R filter contribute zeros to the 8x8 matrix.
      simdgroup_matrix<T, 8, 8> g;
      g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
      g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);

      simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
      wt_out_0[size_t(c) * O] = g_out.thread_elements()[0];
      wt_out_1[size_t(c) * O] = g_out.thread_elements()[1];
    }

    // Advance to the next block of BC channels
    wt_in += BC;
    wt_out_0 += size_t(BC) * O;
    wt_out_1 += size_t(BC) * O;
  }
}
|
||||||
|
|
||||||
|
// Explicitly instantiates winograd_conv_2d_weight_transform<itype, bc> and
// exports it under the host name
// "winograd_conv_2d_weight_transform_<name>_bc<bc>".
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
  template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]] \
  [[kernel]] void winograd_conv_2d_weight_transform<itype, bc>( \
      const device itype* wt_in [[buffer(0)]], \
      device itype* wt_out [[buffer(1)]], \
      const constant int& C [[buffer(2)]], \
      const constant int& O [[buffer(3)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
/**
 * Winograd input transform kernel.
 *
 * One threadgroup processes one A x A input tile (A = WGT::IN_TILE_SIZE).
 * Its WM x WN simdgroups cooperatively stage BC channels of the tile in
 * threadgroup memory; each simdgroup then transforms a strided subset of the
 * staged channels with 8x8 simdgroup matrices, computing (Bt * I) * B, and
 * writes the result to inp_out, stored interleaved as (A x A x tiles x C).
 *
 * NOTE(review): reads from inp_in are not bounds-checked and a partial final
 * channel block is not handled; presumably the host pads the input and only
 * dispatches with C a multiple of BC -- confirm against the caller.
 */
template <typename T,
          int BC,
          int WM,
          int WN,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
    const device T* inp_in [[buffer(0)]],
    device T* inp_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  // lid is part of the signature but is not used by this kernel.
  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
  constexpr int A = WGT::IN_TILE_SIZE;
  constexpr int N_SIMD_GROUPS = WM * WN;

  // Row / column of the first of the two matrix elements held by this lane.
  const short quad = simd_lane_id / 4;
  const short row = (quad & 4) + (simd_lane_id / 2) % 4;
  const short col = (quad & 2) * 2 + (simd_lane_id % 2) * 2;

  // Per-lane elements of the input transform matrix and of its transpose.
  simdgroup_matrix<T, 8, 8> B;
  simdgroup_matrix<T, 8, 8> Bt;
  B.thread_elements()[0] = WGT::in_transform[row][col];
  B.thread_elements()[1] = WGT::in_transform[row][col + 1];
  Bt.thread_elements()[0] = WGT::in_transform[col][row];
  Bt.thread_elements()[1] = WGT::in_transform[col + 1][row];

  // Part of the A x A tile that this simdgroup is responsible for loading.
  constexpr int TH = (A / WM);
  constexpr int TW = (A / WN);
  const int sub_h0 = TH * (simd_group_id / WN);
  const int sub_w0 = TW * (simd_group_id % WN);
  const int in_h0 = M * tid.y + sub_h0;
  const int in_w0 = M * tid.x + sub_w0;

  // Point inp_in at the first element this simdgroup loads.
  inp_in += tid.z * params.in_strides[0]
      + in_h0 * params.in_strides[1]
      + in_w0 * params.in_strides[2];

  // Offsets of every (h, w) position of the sub-tile relative to inp_in.
  int in_offsets[TH][TW];
  for (int dh = 0; dh < TH; ++dh) {
    for (int dw = 0; dw < TW; ++dw) {
      in_offsets[dh][dw] = dh * params.in_strides[1] + dw * params.in_strides[2];
    }
  }

  // inp_out is stored interleaved (A x A x tiles x C)
  size_t n_tiles = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_idx = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t elem_0 = row * 8 + col;
  size_t elem_1 = row * 8 + col + 1;
  device T* dst_0 = inp_out + elem_0 * n_tiles * params.C + tile_idx * params.C;
  device T* dst_1 = inp_out + elem_1 * n_tiles * params.C + tile_idx * params.C;

  // Staging buffer for one tile of BC channels.
  threadgroup T Is[A][A][BC];

  // Walk over the channels, BC at a time.
  for (int bc = 0; bc < params.C; bc += BC) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage this simdgroup's part of the tile in threadgroup memory.
    for (int dh = 0; dh < TH; ++dh) {
      for (int dw = 0; dw < TW; ++dw) {
        const device T* src_px = inp_in + in_offsets[dh][dw];
        for (int c = simd_lane_id; c < BC; c += 32) {
          Is[sub_h0 + dh][sub_w0 + dw][c] = src_px[c];
        }
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Transform the staged channels; simdgroups take channels round-robin.
    for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> I;
      I.thread_elements()[0] = Is[row][col][c];
      I.thread_elements()[1] = Is[row][col + 1][c];

      simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
      dst_0[c] = I_out.thread_elements()[0];
      dst_1[c] = I_out.thread_elements()[1];
    }

    // Advance to the next block of BC channels.
    inp_in += BC;
    dst_0 += BC;
    dst_1 += BC;
  }
}
|
||||||
|
|
||||||
|
// Explicitly instantiates winograd_conv_2d_input_transform<itype, bc, 2, 2>
// and exports it under the host name
// "winograd_conv_2d_input_transform_<name>_bc<bc>".
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
  template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]] \
  [[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
      const device itype* inp_in [[buffer(0)]], \
      device itype* inp_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
/**
 * Winograd output transform kernel.
 *
 * out_in holds the transformed-domain results in shape (A x A x tiles x O).
 * For BO output channels at a time, each simdgroup takes channels
 * round-robin, applies the output transform with 8x8 simdgroup matrices and
 * stores the M x M result in threadgroup memory. The simdgroups then copy
 * their part of the tile to out_out, addressed through params.out_strides,
 * skipping positions that fall outside params.oS.
 *
 * NOTE(review): a partial final block of output channels is not handled;
 * presumably the host only dispatches with O a multiple of BO -- confirm.
 */
template <typename T,
          int BO,
          int WM,
          int WN,
          int M = 6,
          int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
    const device T* out_in [[buffer(0)]],
    device T* out_out [[buffer(1)]],
    const constant MLXConvParams<2>& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_per_grid [[threadgroups_per_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  // lid is part of the signature but is not used by this kernel.
  (void)lid;

  using WGT = WinogradTransforms<M, R, 8>;
  constexpr int N_SIMD_GROUPS = WM * WN;

  // Row / column of the first of the two matrix elements held by this lane.
  const short quad = simd_lane_id / 4;
  const short row = (quad & 4) + (simd_lane_id / 2) % 4;
  const short col = (quad & 2) * 2 + (simd_lane_id % 2) * 2;

  // Per-lane elements of the output transform matrix and of its transpose.
  simdgroup_matrix<T, 8, 8> Amat;
  simdgroup_matrix<T, 8, 8> Amat_t;
  Amat.thread_elements()[0] = WGT::out_transform[row][col];
  Amat.thread_elements()[1] = WGT::out_transform[row][col + 1];
  Amat_t.thread_elements()[0] = WGT::out_transform[col][row];
  Amat_t.thread_elements()[1] = WGT::out_transform[col + 1][row];

  // out_in comes in shape (A x A x tiles x O); after the transform the
  // result is written to out_out in shape (N, H, W, O).

  // Part of the M x M output tile that this simdgroup writes out.
  constexpr int TH = (M / WM);
  constexpr int TW = (M / WN);
  const int sub_h0 = TH * (simd_group_id / WN);
  const int sub_w0 = TW * (simd_group_id % WN);
  const int out_h0 = M * tid.y + sub_h0;
  const int out_w0 = M * tid.x + sub_w0;

  // Move to the correct output tile
  out_out += tid.z * params.out_strides[0]
      + out_h0 * params.out_strides[1]
      + out_w0 * params.out_strides[2];

  // Offsets of every (h, w) position of the sub-tile relative to out_out;
  // -1 marks positions outside the output extent.
  int out_offsets[TH][TW];
  for (int dh = 0; dh < TH; ++dh) {
    for (int dw = 0; dw < TW; ++dw) {
      bool in_bounds = ((out_h0 + dh) < params.oS[0]) && ((out_w0 + dw) < params.oS[1]);
      out_offsets[dh][dw] = in_bounds ? dh * params.out_strides[1] + dw * params.out_strides[2] : -1;
    }
  }

  // out_in is stored interleaved (A x A x tiles x O)
  size_t n_tiles = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
  size_t tile_idx = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
  size_t elem_0 = row * 8 + col;
  size_t elem_1 = row * 8 + col + 1;
  const device T* src_0 = out_in + elem_0 * n_tiles * params.O + tile_idx * params.O;
  const device T* src_1 = out_in + elem_1 * n_tiles * params.O + tile_idx * params.O;

  // Staging buffer for one transformed tile of BO channels.
  threadgroup T Os[M][M][BO];

  // Walk over the output channels, BO at a time.
  for (int bo = 0; bo < params.O; bo += BO) {

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Transform; simdgroups take channels round-robin.
    for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
      simdgroup_matrix<T, 8, 8> O_mat;
      O_mat.thread_elements()[0] = src_0[c];
      O_mat.thread_elements()[1] = src_1[c];

      simdgroup_matrix<T, 8, 8> O_out = (Amat_t * (O_mat * Amat));

      // Only the leading M x M part of the 8x8 result is kept.
      if ((row < M) && (col < M)) {
        Os[row][col][c] = O_out.thread_elements()[0];
      }
      if ((row < M) && ((col + 1) < M)) {
        Os[row][col + 1][c] = O_out.thread_elements()[1];
      }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Copy this simdgroup's part of the tile out of threadgroup memory.
    for (int dh = 0; dh < TH; ++dh) {
      for (int dw = 0; dw < TW; ++dw) {
        if (out_offsets[dh][dw] >= 0) {
          device T* dst_px = out_out + out_offsets[dh][dw];
          for (int c = simd_lane_id; c < BO; c += 32) {
            dst_px[c] = Os[sub_h0 + dh][sub_w0 + dw][c];
          }
        }
      }
    }

    // Advance to the next block of BO channels.
    out_out += BO;
    src_0 += BO;
    src_1 += BO;
  }
}
|
||||||
|
|
||||||
|
// Explicitly instantiates winograd_conv_2d_output_transform<itype, bo, 2, 2>
// and exports it under the host name
// "winograd_conv_2d_output_transform_<name>_bo<bo>".
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
  template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]] \
  [[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
      const device itype* out_in [[buffer(0)]], \
      device itype* out_out [[buffer(1)]], \
      const constant MLXConvParams<2>& params [[buffer(2)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 tgp_per_grid [[threadgroups_per_grid]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
|
// Instantiates the weight, input and output transform kernels for one
// element type, each with a block size of 32.
#define instantiate_winograd_conv_2d(name, itype) \
  instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
  instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
  instantiate_winograd_conv_2d_output_transform(name, itype, 32)

// Kernels are generated for single and half precision.
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
|
269
mlx/backend/metal/kernels/copy.metal
Normal file
269
mlx/backend/metal/kernels/copy.metal
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
// Scalar copy: every thread converts the first element of src to U and
// stores it at its own grid position in dst.
template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src,
    device U* dst,
    uint index [[thread_position_in_grid]]) {
  device U* dst_elem = dst + index;
  *dst_elem = static_cast<U>(*src);
}
|
||||||
|
|
||||||
|
// Element-wise copy: thread i converts src[i] to U and stores it in dst[i].
template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src,
    device U* dst,
    uint index [[thread_position_in_grid]]) {
  device const T* src_elem = src + index;
  device U* dst_elem = dst + index;
  *dst_elem = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// 1-D strided copy: the source location comes from elem_to_loc_1 applied to
// the thread index and src_stride; the destination is indexed directly by
// the thread index.
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src,
    device U* dst,
    constant const size_t& src_stride,
    uint index [[thread_position_in_grid]]) {
  device const T* src_elem = src + elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// 2-D strided copy: the source location comes from elem_to_loc_2; the
// destination is the row-major position of the thread in the launch grid.
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  device const T* src_elem = src + elem_to_loc_2(index, src_strides);
  // The cast keeps the multiplication in size_t rather than 32-bit uint.
  const size_t dst_offset = index.x + (size_t)grid_dim.x * index.y;
  dst[dst_offset] = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// 3-D strided copy: the source location comes from elem_to_loc_3; the
// destination is the row-major position of the thread in the launch grid.
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  device const T* src_elem = src + elem_to_loc_3(index, src_strides);
  // The casts keep the multiplications in size_t rather than 32-bit uint.
  const size_t row = index.y + (size_t)grid_dim.y * index.z;
  const size_t dst_offset = index.x + (size_t)grid_dim.x * row;
  dst[dst_offset] = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// Strided copy for a compile-time number of dimensions DIM: the source
// location comes from elem_to_loc_nd<DIM>; the destination is the row-major
// position of the thread in the launch grid.
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src,
    device U* dst,
    constant const int src_shape[DIM],
    constant const size_t src_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  device const T* src_elem = src + elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  // The casts keep the multiplications in size_t rather than 32-bit uint.
  const size_t row = index.y + (size_t)grid_dim.y * index.z;
  const size_t dst_offset = index.x + (size_t)grid_dim.x * row;
  dst[dst_offset] = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// Strided copy for a run-time number of dimensions: the source location
// comes from elem_to_loc with ndim; the destination is the row-major
// position of the thread in the launch grid.
template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src,
    device U* dst,
    constant const int* src_shape,
    constant const size_t* src_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  device const T* src_elem = src + elem_to_loc(index, src_shape, src_strides, ndim);
  // The casts keep the multiplications in size_t rather than 32-bit uint.
  const size_t row = index.y + (size_t)grid_dim.y * index.z;
  const size_t dst_offset = index.x + (size_t)grid_dim.x * row;
  dst[dst_offset] = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// 1-D copy with strides on both sides: source and destination locations
// both come from elem_to_loc_1, with their respective strides.
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src,
    device U* dst,
    constant const size_t& src_stride,
    constant const size_t& dst_stride,
    uint index [[thread_position_in_grid]]) {
  device const T* src_elem = src + elem_to_loc_1(index, src_stride);
  device U* dst_elem = dst + elem_to_loc_1(index, dst_stride);
  *dst_elem = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// 2-D copy with strides on both sides: source and destination locations
// both come from elem_to_loc_2, with their respective strides.
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[2],
    constant const size_t dst_strides[2],
    uint2 index [[thread_position_in_grid]]) {
  device const T* src_elem = src + elem_to_loc_2(index, src_strides);
  device U* dst_elem = dst + elem_to_loc_2(index, dst_strides);
  *dst_elem = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// 3-D copy with strides on both sides: source and destination locations
// both come from elem_to_loc_3, with their respective strides.
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src,
    device U* dst,
    constant const size_t src_strides[3],
    constant const size_t dst_strides[3],
    uint3 index [[thread_position_in_grid]]) {
  device const T* src_elem = src + elem_to_loc_3(index, src_strides);
  device U* dst_elem = dst + elem_to_loc_3(index, dst_strides);
  *dst_elem = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// Copy with strides on both sides for a compile-time number of dimensions.
// Both locations come from elem_to_loc_nd<DIM>; src_shape is used for the
// destination as well as the source.
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src,
    device U* dst,
    constant const int src_shape[DIM],
    constant const size_t src_strides[DIM],
    constant const size_t dst_strides[DIM],
    uint3 index [[thread_position_in_grid]]) {
  device const T* src_elem = src + elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  device U* dst_elem = dst + elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  *dst_elem = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// Copy with strides on both sides for a run-time number of dimensions.
// Both locations come from elem_to_loc; src_shape is used for the
// destination as well as the source.
template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src,
    device U* dst,
    constant const int* src_shape,
    constant const size_t* src_strides,
    constant const size_t* dst_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]]) {
  device const T* src_elem = src + elem_to_loc(index, src_shape, src_strides, ndim);
  device U* dst_elem = dst + elem_to_loc(index, src_shape, dst_strides, ndim);
  *dst_elem = static_cast<U>(*src_elem);
}
|
||||||
|
|
||||||
|
// Explicitly instantiates copy_<ctype><itype, otype> (ctype is pasted onto
// "copy_", e.g. s or v) and exports it under the host name `name`.
#define instantiate_copy(name, itype, otype, ctype) \
  template [[host_name(name)]] \
  [[kernel]] void copy_##ctype<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Explicitly instantiates the fixed-rank kernels for `dims` dimensions:
//   copy_g_nd<itype, otype, dims>   as host name  name "_" dims
//   copy_gg_nd<itype, otype, dims>  as host name  "g" name "_" dims
#define instantiate_copy_g_dim(name, itype, otype, dims) \
  template [[host_name(name "_" #dims)]] \
  [[kernel]] void copy_g_nd<itype, otype, dims>( \
      device const itype* src, \
      device otype* dst, \
      constant const int src_shape[dims], \
      constant const size_t src_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_" #dims)]] \
  [[kernel]] void copy_gg_nd<itype, otype, dims>( \
      device const itype* src, \
      device otype* dst, \
      constant const int src_shape[dims], \
      constant const size_t src_strides[dims], \
      constant const size_t dst_strides[dims], \
      uint3 index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
|
||||||
|
// Explicitly instantiates the rank-specialized copy kernels for ranks 1-3
// (copy_g_ndN as name "_N", copy_gg_ndN as "g" name "_N") and, through
// instantiate_copy_g_dim, the fixed-rank kernels for ranks 4 and 5.
// The instantiations are grouped by rank.
#define instantiate_copy_g_nd(name, itype, otype) \
  template [[host_name(name "_1")]] \
  [[kernel]] void copy_g_nd1<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t& src_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name("g" name "_1")]] \
  [[kernel]] void copy_gg_nd1<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t& src_stride, \
      constant const size_t& dst_stride, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] \
  [[kernel]] void copy_g_nd2<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_2")]] \
  [[kernel]] void copy_gg_nd2<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[2], \
      constant const size_t dst_strides[2], \
      uint2 index [[thread_position_in_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name "_3")]] \
  [[kernel]] void copy_gg_nd3<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const size_t src_strides[3], \
      constant const size_t dst_strides[3], \
      uint3 index [[thread_position_in_grid]]); \
  instantiate_copy_g_dim(name, itype, otype, 4) \
  instantiate_copy_g_dim(name, itype, otype, 5)
|
||||||
|
|
||||||
|
|
||||||
|
// Explicitly instantiates the run-time-rank kernels:
//   copy_g<itype, otype>   as host name  name
//   copy_gg<itype, otype>  as host name  "g" name
#define instantiate_copy_g(name, itype, otype) \
  template [[host_name(name)]] \
  [[kernel]] void copy_g<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const int* src_shape, \
      constant const size_t* src_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  template [[host_name("g" name)]] \
  [[kernel]] void copy_gg<itype, otype>( \
      device const itype* src, \
      device otype* dst, \
      constant const int* src_shape, \
      constant const size_t* src_strides, \
      constant const size_t* dst_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Instantiates every copy kernel variant for one (itype, otype) pair.
// tname is stringified and appended to the "scopy" / "vcopy" / "gcopy"
// host-name prefixes.
#define instantiate_copy_all(tname, itype, otype) \
  instantiate_copy("scopy" #tname, itype, otype, s) \
  instantiate_copy("vcopy" #tname, itype, otype, v) \
  instantiate_copy_g("gcopy" #tname, itype, otype) \
  instantiate_copy_g_nd("gcopy" #tname, itype, otype)
|
||||||
|
|
||||||
|
// Instantiates the copy kernels from one input type to every supported
// output type. The input type name is pasted in front of the output type
// name to form the suffix handed to instantiate_copy_all.
#define instantiate_copy_itype(itname, itype) \
  instantiate_copy_all(itname ## bool_, itype, bool) \
  instantiate_copy_all(itname ## uint8, itype, uint8_t) \
  instantiate_copy_all(itname ## uint16, itype, uint16_t) \
  instantiate_copy_all(itname ## uint32, itype, uint32_t) \
  instantiate_copy_all(itname ## uint64, itype, uint64_t) \
  instantiate_copy_all(itname ## int8, itype, int8_t) \
  instantiate_copy_all(itname ## int16, itype, int16_t) \
  instantiate_copy_all(itname ## int32, itype, int32_t) \
  instantiate_copy_all(itname ## int64, itype, int64_t) \
  instantiate_copy_all(itname ## float16, itype, half) \
  instantiate_copy_all(itname ## float32, itype, float) \
  instantiate_copy_all(itname ## bfloat16, itype, bfloat16_t) \
  instantiate_copy_all(itname ## complex64, itype, complex64_t)
|
||||||
|
|
||||||
|
// Copy kernels for every supported input type.

// Boolean and unsigned integer inputs
instantiate_copy_itype(bool_, bool)
instantiate_copy_itype(uint8, uint8_t)
instantiate_copy_itype(uint16, uint16_t)
instantiate_copy_itype(uint32, uint32_t)
instantiate_copy_itype(uint64, uint64_t)

// Signed integer inputs
instantiate_copy_itype(int8, int8_t)
instantiate_copy_itype(int16, int16_t)
instantiate_copy_itype(int32, int32_t)
instantiate_copy_itype(int64, int64_t)

// Floating point and complex inputs
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
|
68
mlx/backend/metal/kernels/erf.h
Normal file
68
mlx/backend/metal/kernels/erf.h
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_math>
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Approximation to the error function.
|
||||||
|
* Based on code from:
|
||||||
|
* https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
|
||||||
|
*/
|
||||||
|
float erf(float a) {
  const float abs_a = metal::abs(a);
  const float a_sq = a * a;

  // Central interval (written as a negated '>' so that NaN takes this path).
  if (!(abs_a > 0.927734375f)) {
    // maximum error 0.98929 ulp
    float poly = -5.96761703e-4f; // -0x1.38e000p-11
    poly = metal::fma(poly, a_sq, 4.99119423e-3f); // 0x1.471a58p-8
    poly = metal::fma(poly, a_sq, -2.67681349e-2f); // -0x1.b691b2p-6
    poly = metal::fma(poly, a_sq, 1.12819925e-1f); // 0x1.ce1c44p-4
    poly = metal::fma(poly, a_sq, -3.76125336e-1f); // -0x1.812700p-2
    poly = metal::fma(poly, a_sq, 1.28379166e-1f); // 0x1.06eba8p-3
    return metal::fma(poly, a, a);
  }

  // Tails: maximum error 0.99527 ulp
  const float head = metal::fma(
      -1.72853470e-5f, abs_a, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
  const float tail = metal::fma(
      -3.88396438e-3f, abs_a, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
  float q = metal::fma(head, a_sq, tail);
  q = metal::fma(q, abs_a, -1.06777877e-1f); // -0x1.b55cb8p-4
  q = metal::fma(q, abs_a, -6.34846687e-1f); // -0x1.450aa0p-1
  q = metal::fma(q, abs_a, -1.28717512e-1f); // -0x1.079d0cp-3
  q = metal::fma(q, abs_a, -abs_a);
  // TODO, replace with expm1 when implemented
  const float magnitude = 1.0f - metal::exp(q);
  // The result takes the sign of the input.
  return metal::copysign(magnitude, a);
}
|
||||||
|
|
||||||
|
// Approximation to the inverse error function: evaluates one of two
// polynomials in log(1 - a * a), selected by the magnitude of that
// logarithm, and scales the result by a.
float erfinv(float a) {
  const float log_term = metal::log(metal::fma(a, 0.0f - a, 1.0f));
  float poly;

  // Negated '>' so that a NaN logarithm takes the same path as before.
  if (!(metal::abs(log_term) > 6.125f)) { // maximum ulp error = 2.35002
    poly = 5.43877832e-9f; // 0x1.75c000p-28
    poly = metal::fma(poly, log_term, 1.43285448e-7f); // 0x1.33b402p-23
    poly = metal::fma(poly, log_term, 1.22774793e-6f); // 0x1.499232p-20
    poly = metal::fma(poly, log_term, 1.12963626e-7f); // 0x1.e52cd2p-24
    poly = metal::fma(poly, log_term, -5.61530760e-5f); // -0x1.d70bd0p-15
    poly = metal::fma(poly, log_term, -1.47697632e-4f); // -0x1.35be90p-13
    poly = metal::fma(poly, log_term, 2.31468678e-3f); // 0x1.2f6400p-9
    poly = metal::fma(poly, log_term, 1.15392581e-2f); // 0x1.7a1e50p-7
    poly = metal::fma(poly, log_term, -2.32015476e-1f); // -0x1.db2aeep-3
    poly = metal::fma(poly, log_term, 8.86226892e-1f); // 0x1.c5bf88p-1
  } else { // maximum ulp error = 2.35793
    poly = 3.03697567e-10f; // 0x1.4deb44p-32
    poly = metal::fma(poly, log_term, 2.93243101e-8f); // 0x1.f7c9aep-26
    poly = metal::fma(poly, log_term, 1.22150334e-6f); // 0x1.47e512p-20
    poly = metal::fma(poly, log_term, 2.84108955e-5f); // 0x1.dca7dep-16
    poly = metal::fma(poly, log_term, 3.93552968e-4f); // 0x1.9cab92p-12
    poly = metal::fma(poly, log_term, 3.02698812e-3f); // 0x1.8cc0dep-9
    poly = metal::fma(poly, log_term, 4.83185798e-3f); // 0x1.3ca920p-8
    poly = metal::fma(poly, log_term, -2.64646143e-1f); // -0x1.0eff66p-2
    poly = metal::fma(poly, log_term, 8.40016484e-1f); // 0x1.ae16a4p-1
  }

  return a * poly;
}
|
91
mlx/backend/metal/kernels/gemm.metal
Normal file
91
mlx/backend/metal/kernels/gemm.metal
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
#include "mlx/backend/metal/kernels/gemm/gemm.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// GEMM kernel entry point. Declares the threadgroup scratch buffer, whose
// size is given by GEMMKernel::tgp_mem_size, and forwards every argument to
// GEMMKernel::run for the selected tile / transpose / alignment
// configuration.
template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *C [[buffer(2)]],
    const constant int &M [[buffer(3)]],
    const constant int &N [[buffer(4)]],
    const constant int &K [[buffer(5)]],
    const constant int &batch_stride_a [[buffer(6)]],
    const constant int &batch_stride_b [[buffer(7)]],
    const constant int &batch_stride_c [[buffer(8)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using kernel_t = GEMMKernel<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  threadgroup T scratch[kernel_t::tgp_mem_size];

  kernel_t::run(
      A,
      B,
      C,
      M,
      N,
      K,
      batch_stride_a,
      batch_stride_b,
      batch_stride_c,
      scratch,
      simd_lane_id,
      simd_group_id,
      tid,
      lid);
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernel initializations
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Explicitly instantiates one gemm kernel configuration and exports it under
// a host name built from the transpose tag, the type names, the tile sizes
// and the alignment tags. otype is accepted but does not appear in the
// instantiation; only oname contributes, to the host name.
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device itype *C [[buffer(2)]], \
      const constant int &M [[buffer(3)]], \
      const constant int &N [[buffer(4)]], \
      const constant int &K [[buffer(5)]], \
      const constant int &batch_stride_a [[buffer(6)]], \
      const constant int &batch_stride_b [[buffer(7)]], \
      const constant int &batch_stride_c [[buffer(8)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// Instantiates the four alignment variants of one tile configuration: every
// combination of (taligned -> true, naligned -> false) for the MN flag and
// for the K flag.
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||||
|
|
||||||
|
// Instantiates the four transpose variants of one tile configuration. The
// two-letter tag gives (transpose_a, transpose_b): n -> false, t -> true.
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
|
// Instantiates every supported (bm, bn, bk, wm, wn) tile configuration for
// one input/output type pair.
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
|
||||||
|
|
||||||
|
// GEMM kernels are generated for half, float and bfloat16_t; input and
// output use the same type in every case.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
|
||||||
|
// TODO: Accumulation in different type
|
99
mlx/backend/metal/kernels/random.metal
Normal file
99
mlx/backend/metal/kernels/random.metal
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
// Left-rotation amounts consumed by threefry2x32_hash, which selects row
// (round % 2) on each of its rounds.
static constexpr constant uint32_t rotations[2][4] = {
    {13, 15, 26, 6},
    {17, 29, 16, 24}};
|
||||||
|
|
||||||
|
// Hash output viewed either as two 32-bit words or as two groups of four
// bytes; bytes[0] overlays val.x and bytes[1] overlays val.y.
union rbits {
  uint2 val;
  uchar4 bytes[2];
};
|
||||||
|
|
||||||
|
// Hashes the counter pair `count` under `key` and returns the two resulting
// 32-bit words.
//
// The routine runs five rounds. Each round applies four add / rotate-left /
// xor mixing steps using one row of `rotations`, then injects two words of
// the key schedule plus the 1-based round number.
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
  // Key schedule: both key words and their XOR with the constant 0x1BD11BDA.
  // Only lanes 0..2 are ever read (all indices below are taken modulo 3).
  uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};

  rbits v;
  v.val.x = count.x + ks[0];
  v.val.y = count.y + ks[1];

  for (int round = 0; round < 5; ++round) {
    for (int step = 0; step < 4; ++step) {
      uint32_t rot = rotations[round % 2][step];
      v.val.x += v.val.y;
      v.val.y = (v.val.y << rot) | (v.val.y >> (32 - rot));
      v.val.y ^= v.val.x;
    }
    v.val.x += ks[(round + 1) % 3];
    v.val.y += ks[(round + 2) % 3] + round + 1;
  }

  return v;
}
|
||||||
|
|
||||||
|
// Random-bits kernel for keys read with a fixed layout: key `index.x` is the
// pair keys[2 * index.x], keys[2 * index.x + 1].
//
// Thread (x, y) hashes key x with the counter pair (y, y + grid_dim.y) and
// writes the first result word at word slot y and the second at word slot
// y + grid_dim.y of that key's `bytes_per_key`-byte slice of `out`.
[[kernel]] void rbitsc(
    device const uint32_t* keys,
    device char* out,
    device const bool& odd,
    device const uint& bytes_per_key,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto key_offset = 2 * index.x;
  auto key = uint2(keys[key_offset], keys[key_offset + 1]);

  // Advance to the slice of the output owned by this key.
  out += index.x * bytes_per_key;

  // When `odd` is set, the thread at index.y == half_size emits only its
  // first word and hashes with a zero second counter.
  auto half_size = grid_dim.y - odd;
  bool drop_last = odd && (index.y == half_size);

  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);

  // First word: always four bytes.
  for (int b = 0; b < 4; ++b) {
    out[4 * count.x + b] = bits.bytes[0][b];
  }
  if (drop_last) {
    return;
  }

  // Second word: the row just before half_size writes only the leftover
  // bytes when bytes_per_key is not a multiple of four; all others write four.
  uint leftover = bytes_per_key % 4;
  bool partial = (index.y + 1) == half_size && leftover > 0;
  int n_bytes = partial ? (int)leftover : 4;
  for (int b = 0; b < n_bytes; ++b) {
    out[4 * count.y + b] = bits.bytes[1][b];
  }
}
|
||||||
|
|
||||||
|
// Random-bits kernel for keys addressed through shape/strides: the two words
// of key `index.x` are located with elem_to_loc on logical elements
// 2 * index.x and 2 * index.x + 1.
//
// Apart from the key lookup, the hashing and the output writes are the same
// as in `rbitsc`.
[[kernel]] void rbits(
    device const uint32_t* keys,
    device char* out,
    device const bool& odd,
    device const uint& bytes_per_key,
    device const int& ndim,
    device const int* key_shape,
    device const size_t* key_strides,
    uint2 grid_dim [[threads_per_grid]],
    uint2 index [[thread_position_in_grid]]) {
  auto key_offset = 2 * index.x;
  auto first_loc = elem_to_loc(key_offset, key_shape, key_strides, ndim);
  auto second_loc = elem_to_loc(key_offset + 1, key_shape, key_strides, ndim);
  auto key = uint2(keys[first_loc], keys[second_loc]);

  // Advance to the slice of the output owned by this key.
  out += index.x * bytes_per_key;

  // When `odd` is set, the thread at index.y == half_size emits only its
  // first word and hashes with a zero second counter.
  auto half_size = grid_dim.y - odd;
  bool drop_last = odd && (index.y == half_size);

  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);

  // First word: always four bytes.
  for (int b = 0; b < 4; ++b) {
    out[4 * count.x + b] = bits.bytes[0][b];
  }
  if (drop_last) {
    return;
  }

  // Second word: the row just before half_size writes only the leftover
  // bytes when bytes_per_key is not a multiple of four; all others write four.
  uint leftover = bytes_per_key % 4;
  bool partial = (index.y + 1) == half_size && leftover > 0;
  int n_bytes = partial ? (int)leftover : 4;
  for (int b = 0; b < n_bytes; ++b) {
    out[4 * count.y + b] = bits.bytes[1][b];
  }
}
|
536
mlx/backend/metal/kernels/reduce.metal
Normal file
536
mlx/backend/metal/kernels/reduce.metal
Normal file
@ -0,0 +1,536 @@
|
|||||||
|
#include <metal_atomic>
|
||||||
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/kernels/reduce.h"
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// Number of per-simdgroup partial results the reduction kernels in this file
// reserve in threadgroup memory (`local_vals[simd_size]`).
static constant uint8_t simd_size = 32;
|
||||||
|
|
||||||
|
// Fills the output buffer with the reduction's identity value: each thread
// stores Op::init at its own grid position.
template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T *out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  device T &slot = out[tid];
  slot = Op::init;
}
|
||||||
|
|
||||||
|
// Explicitly instantiates `init_reduce` for one (output type, op) pair and
// exports it to the host as "i<name>".
//
// The buffer index on `out` is 0 so that this declaration agrees with the
// `init_reduce` template definition; it previously said buffer(1).
#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] \
  [[kernel]] void init_reduce<otype, op>( \
      device otype *out [[buffer(0)]], \
      uint tid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// All reduce
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Reduces the whole input to a single value.
//
// Every thread accumulates a private partial over blocks of N_READS
// consecutive elements, stepping through the input by grid_size * N_READS.
// Partials are then combined within each simdgroup, across the simdgroups of
// the threadgroup through threadgroup memory, and finally merged into `out`
// with one atomic update per threadgroup.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // NB: this kernel assumes threads_per_threadgroup is at most
  // 1024. This way with a simd_size of 32, we are guaranteed to
  // complete the reduction in two steps of simd-level reductions.

  Op op;
  threadgroup U local_vals[simd_size];

  U total_val = Op::init;

  in += gid * N_READS;

  // Every pass except the last one is known to be fully in bounds.
  int full_passes = (int)ceildiv(in_size, grid_size * N_READS) - 1;

  int r = 0;
  for (; r < full_passes; r++) {
    U block[N_READS] = {op.init};

    for (int k = 0; k < N_READS; k++) {
      block[k] = static_cast<U>(in[k]);
    }
    for (int k = 0; k < N_READS; k++) {
      total_val = op(block[k], total_val);
    }

    in += grid_size * N_READS;
  }

  // Separate case for the last pass, which may run past the end of the input.
  size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
  if (curr_idx < in_size) {
    int max_reads = in_size - curr_idx;
    T block[N_READS];

    // Out-of-range slots re-read the last valid element so no load goes out
    // of bounds; those slots are replaced by Op::init below.
    for (int k = 0; k < N_READS; k++) {
      int src = k < max_reads ? k : max_reads - 1;
      block[k] = in[src];
    }
    for (int k = 0; k < N_READS; k++) {
      U val = k < max_reads ? block[k] : Op::init;
      total_val = op(static_cast<U>(val), total_val);
    }
  }

  // Reduction within simd group
  total_val = op.simd_reduce(total_val);
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }

  // Reduction within thread group
  threadgroup_barrier(mem_flags::mem_threadgroup);
  total_val = lid < simd_per_group ? local_vals[lid] : op.init;
  total_val = op.simd_reduce(total_val);

  // Reduction across threadgroups
  if (lid == 0) {
    op.atomic_update(out, total_val);
  }
}
|
||||||
|
|
||||||
|
// Explicitly instantiates `all_reduce` for one (input type, output type, op)
// triple and exports it to the host as "all_reduce_<name>".
#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] \
  [[kernel]] void all_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// General reduce
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Fallback reduction for arbitrary layouts with a runtime number of
// dimensions. Each thread handles one input element: it maps its grid
// position through (in_shape, in_strides) to find the element and through
// (in_shape, out_strides) to find the output slot, then merges the value into
// that slot atomically.
template <typename T, typename U, typename Op>
[[kernel]] void general_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device int *in_shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    const device size_t& ndim [[buffer(5)]],
    uint gid [[thread_position_in_grid]]) {
  Op op;

  auto src_loc = elem_to_loc(gid, in_shape, in_strides, ndim);
  auto dst_loc = elem_to_loc(gid, in_shape, out_strides, ndim);

  U value = static_cast<U>(in[src_loc]);
  op.atomic_update(out, value, dst_loc);
}
|
||||||
|
|
||||||
|
// Variant of the general reduction with the number of dimensions fixed at
// compile time (NDIM). Each thread handles one input element: it resolves
// the input location and the output slot with elem_to_loc_nd<NDIM>, then
// merges the value into the output atomically.
template <typename T, typename U, typename Op, int NDIM>
[[kernel]] void general_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const device int *in_shape [[buffer(2)]],
    const device size_t *in_strides [[buffer(3)]],
    const device size_t *out_strides [[buffer(4)]],
    uint gid [[thread_position_in_grid]]) {
  Op op;

  auto src_loc = elem_to_loc_nd<NDIM>(gid, in_shape, in_strides);
  auto dst_loc = elem_to_loc_nd<NDIM>(gid, in_shape, out_strides);

  U value = static_cast<U>(in[src_loc]);
  op.atomic_update(out, value, dst_loc);
}
|
||||||
|
|
||||||
|
// Explicitly instantiates the runtime-ndim `general_reduce` and exports it
// to the host as "general_reduce_<name>".
#define instantiate_general_reduce_helper(name, itype, otype, op) \
  template [[host_name("general_reduce_" #name)]] \
  [[kernel]] void general_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device int *in_shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
      const device size_t *out_strides [[buffer(4)]], \
      const device size_t& ndim [[buffer(5)]], \
      uint gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Explicitly instantiates the fixed-NDIM `general_reduce` for n dimensions
// and exports it to the host as "general_reduce_<name>_dim_<n>".
#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \
  template [[host_name("general_reduce_" #name "_dim_" #n)]] \
  [[kernel]] void general_reduce<itype, otype, op, n>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const device int *in_shape [[buffer(2)]], \
      const device size_t *in_strides [[buffer(3)]], \
      const device size_t *out_strides [[buffer(4)]], \
      uint gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Instantiates the runtime-ndim general reduction plus the specializations
// for 1 to 4 dimensions.
#define instantiate_general_reduce(name, itype, otype, op) \
  instantiate_general_reduce_helper(name, itype, otype, op) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \
  instantiate_general_reduce_helper_nd(name, itype, otype, op, 4)
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Row reduce
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Reduces contiguous rows: threadgroup `tid` reduces the `reduction_size`
// consecutive elements starting at in[tid * reduction_size] and writes the
// result to out[tid].
//
// Each thread accumulates blocks of N_READS elements, stepping through the
// row by lsize * N_READS. Partials are combined within each simdgroup and
// then, via threadgroup memory, across the simdgroups of the threadgroup.
// Values are converted to the accumulator type U before being combined.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce(
    const device T *in [[buffer(0)]],
    device U *out [[buffer(1)]],
    const device size_t& reduction_size [[buffer(2)]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {

  Op op;

  // Each threadgroup handles 1 reduction
  in += tid * reduction_size + lid * N_READS;

  // The reduction is accumulated here
  U total_val = Op::init;
  threadgroup U local_vals[simd_size];

  // Loop over the reduction size within thread group
  int r = 0;
  for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) {
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      vals[i] = in[i];
    }
    for(int i = 0; i < N_READS; i++) {
      total_val = op(static_cast<U>(vals[i]), total_val);
    }

    in += lsize * N_READS;
  }

  // Separate case for the last set as we close the reduction size
  size_t reduction_index = (lid + (size_t)lsize * r) * N_READS;
  if(reduction_index < reduction_size) {
    int max_reads = reduction_size - reduction_index;

    // Load raw input values; slots past the end of the row re-read the last
    // valid element so no load goes out of bounds.
    T vals[N_READS];
    for(int i = 0; i < N_READS; i++) {
      int idx = min(i, max_reads - 1);
      vals[i] = in[idx];
    }
    // Convert to the accumulator type exactly once and substitute the
    // identity for padded slots in U, so Op::init is never narrowed to T.
    for(int i = 0; i < N_READS; i++) {
      U val = i < max_reads ? static_cast<U>(vals[i]) : Op::init;
      total_val = op(val, total_val);
    }
  }

  total_val = op.simd_reduce(total_val);

  // Prepare next level
  if (simd_lane_id == 0) {
    local_vals[simd_group_id] = total_val;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Reduction within thread group
  // Only needed if multiple simd groups
  if(reduction_size > simd_size) {
    total_val = lid < simd_per_group ? local_vals[lid] : op.init;
    total_val = op.simd_reduce(total_val);
  }
  // Update output
  if (lid == 0) {
    out[tid] = total_val;
  }
}
|
||||||
|
|
||||||
|
// Explicitly instantiates `row_reduce` for one (input type, output type, op)
// triple and exports it to the host as "row_reduce_<name>".
#define instantiate_row_reduce(name, itype, otype, op) \
  template [[host_name("row_reduce_" #name)]] \
  [[kernel]] void row_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const device size_t& reduction_size [[buffer(2)]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Column reduce
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Shared body of the strided (column-style) reduction kernels.
//
// The thread at (lid.x, lid.y) in threadgroup `tid` reduces up to N_READS
// elements located at
//   in[in_idx + k * reduction_stride],
//   k = (tid.y * lsize.y + lid.y) * N_READS + r,  r in [0, N_READS)
// and stores its partial at local_data[lsize.y * lid.x + lid.y]. After a
// threadgroup barrier, the lid.y == 0 thread of each column folds the
// lsize.y partials of that column and merges the result into out[out_idx]
// with an atomic update.
//
// `local_data` must hold at least lsize.x * lsize.y values of type U.
//
// NOTE(review): this function contains a threadgroup barrier, but both
// kernels call it inside `if (out_idx < out_size)`. Threads of a threadgroup
// that fail that test skip the barrier -- confirm the launch configuration
// makes this safe.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline void _contiguous_strided_reduce(
    const device T *in,
    device mlx_atomic<U> *out,
    threadgroup U *local_data,
    uint in_idx,
    uint out_idx,
    uint reduction_size,
    uint reduction_stride,
    uint2 tid,
    uint2 lid,
    uint2 lsize) {

  Op op;
  T local_vals[N_READS];

  uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;

  for(uint r = 0; r < N_READS; r++) {
    uint offset = base_offset + r;
    // Clamp so that the load is always in bounds; clamped slots are skipped
    // by the accumulation loop below.
    offset = offset < reduction_size ? offset : reduction_size - 1;
    // Form the element index in 64 bits: offset * reduction_stride can
    // exceed the 32-bit range even when both factors fit in a uint.
    size_t loc = in_idx + static_cast<size_t>(offset) * reduction_stride;
    local_vals[r] = in[loc];
  }

  U total_val = Op::init;
  for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
    // Convert the loaded value (type T) to the accumulator type; total_val
    // is already a U.
    total_val = op(static_cast<U>(local_vals[r]), total_val);
  }
  local_data[lsize.y * lid.x + lid.y] = total_val;

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if(lid.y == 0) {
    U val = op.init;

    for(uint i = 0; i < lsize.y; i++) {
      val = op(val, local_data[lsize.y * lid.x + i]);
    }

    op.atomic_update(out, val, out_idx);
  }
}
|
||||||
|
|
||||||
|
// Column reduction: the thread column at x-position
// tid.x * lsize.x + lid.x reduces the elements
// in[out_idx + k * reduction_stride] into out[out_idx], with the work along
// the reduction axis split over the y dimension of the grid. Columns at or
// beyond out_size do nothing.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;

  if (out_idx >= out_size) {
    return;
  }

  // The input column starts at the same index as the output element.
  _contiguous_strided_reduce<T, U, Op, N_READS>(
      in, out, local_data,
      out_idx, out_idx,
      reduction_size, reduction_stride,
      tid, lid, lsize);
}
|
||||||
|
|
||||||
|
// Explicitly instantiates `col_reduce` for one (input type, output type, op)
// triple and exports it to the host as "col_reduce_<name>".
#define instantiate_col_reduce(name, itype, otype, op) \
  template [[host_name("col_reduce_" #name)]] \
  [[kernel]] void col_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint2 tid [[threadgroup_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]]);
|
||||||
|
|
||||||
|
// Strided reduction where the start of each reduced column is found by
// mapping the output index through (in_shape, in_strides) with a
// compile-time number of dimensions (NDIM).
//
// N_READS defaults to REDUCE_N_READS, like every other reduction template in
// this file, instead of a separately hard-coded literal. Columns at or
// beyond out_size do nothing, and the input location is only computed for
// columns that pass that check.
template <typename T, typename U, typename Op, int NDIM, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const device int* in_shape [[buffer(5)]],
    const device size_t* in_strides [[buffer(6)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {

  auto out_idx = tid.x * lsize.x + lid.x;

  if(out_idx < out_size) {
    auto in_idx = elem_to_loc_nd<NDIM>(out_idx, in_shape, in_strides);

    _contiguous_strided_reduce<T, U, Op, N_READS>(
        in,
        out,
        local_data,
        in_idx,
        out_idx,
        reduction_size,
        reduction_stride,
        tid,
        lid,
        lsize);
  }
}
|
||||||
|
|
||||||
|
// Strided reduction where the start of each reduced column is found by
// mapping the output index through (in_shape, in_strides) with a runtime
// number of dimensions (in_dim). Columns at or beyond out_size do nothing.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void contiguous_strided_reduce(
    const device T *in [[buffer(0)]],
    device mlx_atomic<U> *out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const device int* in_shape [[buffer(5)]],
    const device size_t* in_strides [[buffer(6)]],
    const device size_t& in_dim [[buffer(7)]],
    threadgroup U *local_data [[threadgroup(0)]],
    uint2 tid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]]) {
  auto out_idx = tid.x * lsize.x + lid.x;
  auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim);

  if (out_idx >= out_size) {
    return;
  }

  _contiguous_strided_reduce<T, U, Op, N_READS>(
      in, out, local_data,
      in_idx, out_idx,
      reduction_size, reduction_stride,
      tid, lid, lsize);
}
|
||||||
|
|
||||||
|
// Explicitly instantiates the runtime-dim `contiguous_strided_reduce` and
// exports it to the host as "contiguous_strided_reduce_<name>".
#define instantiate_contiguous_strided_helper(name, itype, otype, op) \
  template [[host_name("contiguous_strided_reduce_" #name)]] \
  [[kernel]] void contiguous_strided_reduce<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const device int* in_shape [[buffer(5)]], \
      const device size_t* in_strides [[buffer(6)]], \
      const device size_t& in_dim [[buffer(7)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint2 tid [[threadgroup_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]]);
|
||||||
|
|
||||||
|
// Explicitly instantiates the fixed-NDIM `contiguous_strided_reduce` for n
// dimensions and exports it to the host as
// "contiguous_strided_reduce_<name>_dim_<n>".
#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \
  template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \
  [[kernel]] void contiguous_strided_reduce<itype, otype, op, n>( \
      const device itype *in [[buffer(0)]], \
      device mlx_atomic<otype> *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const device int* in_shape [[buffer(5)]], \
      const device size_t* in_strides [[buffer(6)]], \
      threadgroup otype *local_data [[threadgroup(0)]], \
      uint2 tid [[threadgroup_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]]);
|
||||||
|
|
||||||
|
// Instantiates the runtime-dim strided reduction plus the specializations
// for 1 to 4 dimensions.
#define instantiate_contiguous_strided(name, itype, otype, op) \
  instantiate_contiguous_strided_helper(name, itype, otype, op) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \
  instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4)
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Instantiations
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Instantiates every reduction kernel family (all, row, column, strided and
// general) for one (input type, output type, op) triple.
#define instantiate_reduce(name, itype, otype, op) \
  instantiate_all_reduce(name, itype, otype, op) \
  instantiate_row_reduce(name, itype, otype, op) \
  instantiate_col_reduce(name, itype, otype, op) \
  instantiate_contiguous_strided(name, itype, otype, op) \
  instantiate_general_reduce(name, itype, otype, op)
|
||||||
|
|
||||||
|
// For reductions whose input and output types are the same: instantiates the
// init kernel and every reduction kernel for op<type>. `name` and `tname`
// are pasted together to form the kernel-name suffix.
#define instantiate_same_reduce(name, tname, type, op) \
  instantiate_init_reduce(name ##tname, type, op<type>) \
  instantiate_reduce(name ##tname, type, type, op<type>)
|
||||||
|
|
||||||
|
// Pastes `name` and `tname` into a single kernel-name suffix and forwards to
// instantiate_reduce.
#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
  instantiate_reduce(name ##tname, itype, otype, op)
|
||||||
|
|
||||||
|
// Instantiates a reduction with a fixed output type for each of the input
// types listed below.
#define instantiate_reduce_from_types(name, otype, op) \
  instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
  instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)
|
||||||
|
|
||||||
|
// ---- Sum ------------------------------------------------------------------
// Special case: bool input is summed into a larger (uint32_t) output type.
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)

// ---- Prod -----------------------------------------------------------------
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)

// ---- Sum / Prod for bfloat16 ------------------------------------------------
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)

// ---- And / Or: any input type, bool output ---------------------------------
instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)

instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)

// ---- Min / Max --------------------------------------------------------------
// The name prefixes carry a trailing underscore because the compiler
// segfaulted with the names "min" or "max".
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)

instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)

// ---- Min / Max for bfloat16 -------------------------------------------------
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
|
284
mlx/backend/metal/kernels/unary.metal
Normal file
284
mlx/backend/metal/kernels/unary.metal
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
#include <metal_integer>
|
||||||
|
#include <metal_math>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/erf.h"
|
||||||
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
|
|
||||||
|
// Element-wise absolute value.
struct Abs {
  // Generic case: defer to metal::abs.
  template <typename T>
  T operator()(T x) {
    return metal::abs(x);
  }

  // Unsigned integers and bool are returned unchanged.
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  }

  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  }

  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  }

  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  }

  template <>
  bool operator()(bool x) {
    return x;
  }

  // Complex input: the magnitude sqrt(real^2 + imag^2) is returned in the
  // real part, with a zero imaginary part.
  template <>
  complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse cosine via metal::precise::acos.
struct ArcCos {
  template <typename T>
  T operator()(T value) {
    return metal::precise::acos(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse hyperbolic cosine via metal::precise::acosh.
struct ArcCosh {
  template <typename T>
  T operator()(T value) {
    return metal::precise::acosh(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse sine via metal::precise::asin.
struct ArcSin {
  template <typename T>
  T operator()(T value) {
    return metal::precise::asin(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse hyperbolic sine via metal::precise::asinh.
struct ArcSinh {
  template <typename T>
  T operator()(T value) {
    return metal::precise::asinh(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse tangent via metal::precise::atan.
struct ArcTan {
  template <typename T>
  T operator()(T value) {
    return metal::precise::atan(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse hyperbolic tangent via metal::precise::atanh.
struct ArcTanh {
  template <typename T>
  T operator()(T value) {
    return metal::precise::atanh(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise cosine.
struct Cos {
  template <typename T>
  T operator()(T value) {
    return metal::precise::cos(value);
  }

  // cos(a + ib) = cos(a) cosh(b) - i sin(a) sinh(b)
  template <>
  complex64_t operator()(complex64_t z) {
    auto re = metal::precise::cos(z.real) * metal::precise::cosh(z.imag);
    auto im = -metal::precise::sin(z.real) * metal::precise::sinh(z.imag);
    return {re, im};
  }
};
|
||||||
|
|
||||||
|
// Elementwise hyperbolic cosine.
struct Cosh {
  template <typename T>
  T operator()(T value) {
    return metal::precise::cosh(value);
  }

  // cosh(a + ib) = cosh(a) cos(b) + i sinh(a) sin(b)
  template <>
  complex64_t operator()(complex64_t z) {
    auto re = metal::precise::cosh(z.real) * metal::precise::cos(z.imag);
    auto im = metal::precise::sinh(z.real) * metal::precise::sin(z.imag);
    return {re, im};
  }
};
|
||||||
|
|
||||||
|
// Elementwise error function: the input is widened to float, passed to
// erf(), and the result is cast back to T.
struct Erf {
  template <typename T>
  T operator()(T value) {
    float widened = static_cast<float>(value);
    return static_cast<T>(erf(widened));
  }
};
|
||||||
|
|
||||||
|
// Elementwise inverse error function: the input is widened to float, passed
// to erfinv(), and the result is cast back to T.
struct ErfInv {
  template <typename T>
  T operator()(T value) {
    float widened = static_cast<float>(value);
    return static_cast<T>(erfinv(widened));
  }
};
|
||||||
|
|
||||||
|
// Elementwise exponential.
struct Exp {
  template <typename T>
  T operator()(T value) {
    return metal::precise::exp(value);
  }

  // exp(a + ib) = exp(a) * (cos(b) + i sin(b))
  template <>
  complex64_t operator()(complex64_t z) {
    auto scale = metal::precise::exp(z.real);
    auto re = scale * metal::precise::cos(z.imag);
    auto im = scale * metal::precise::sin(z.imag);
    return {re, im};
  }
};
|
||||||
|
|
||||||
|
// Elementwise natural logarithm via metal::precise::log.
struct Log {
  template <typename T>
  T operator()(T value) {
    return metal::precise::log(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise base-2 logarithm via metal::precise::log2.
struct Log2 {
  template <typename T>
  T operator()(T value) {
    return metal::precise::log2(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise base-10 logarithm via metal::precise::log10.
struct Log10 {
  template <typename T>
  T operator()(T value) {
    return metal::precise::log10(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise log1p. Calls the unqualified log1p() helper, which is not a
// metal::precise function (presumably provided by one of the included
// kernel headers -- confirm).
struct Log1p {
  template <typename T>
  T operator()(T value) {
    return log1p(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise logical negation (operator!), converted back to T.
struct LogicalNot {
  template <typename T>
  T operator()(T value) {
    return !value;
  }
};
|
||||||
|
|
||||||
|
// Elementwise arithmetic negation (unary minus).
struct Negative {
  template <typename T>
  T operator()(T value) {
    return -value;
  }
};
|
||||||
|
|
||||||
|
// Elementwise logistic sigmoid.
//
// Evaluates 1 / (1 + exp(-|x|)), so the exponent is never positive, and
// mirrors the result (1 - y) for negative inputs.
struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto upper_half = 1 / (1 + metal::exp(-metal::abs(x)));
    return (x < 0) ? 1 - upper_half : upper_half;
  }
};
|
||||||
|
|
||||||
|
// Elementwise sign: (x > 0) - (x < 0), i.e. 1, 0 or -1 converted to T.
struct Sign {
  template <typename T>
  T operator()(T value) {
    return (value > T(0)) - (value < T(0));
  }

  // uint32_t can never be negative, so the sign is just "is non-zero".
  template <>
  uint32_t operator()(uint32_t value) {
    return value != 0;
  }
};
|
||||||
|
|
||||||
|
// Elementwise sine.
struct Sin {
  template <typename T>
  T operator()(T value) {
    return metal::precise::sin(value);
  }

  // sin(a + ib) = sin(a) cosh(b) + i cos(a) sinh(b)
  template <>
  complex64_t operator()(complex64_t z) {
    auto re = metal::precise::sin(z.real) * metal::precise::cosh(z.imag);
    auto im = metal::precise::cos(z.real) * metal::precise::sinh(z.imag);
    return {re, im};
  }
};
|
||||||
|
|
||||||
|
// Elementwise hyperbolic sine.
struct Sinh {
  template <typename T>
  T operator()(T value) {
    return metal::precise::sinh(value);
  }

  // sinh(a + ib) = sinh(a) cos(b) + i cosh(a) sin(b)
  template <>
  complex64_t operator()(complex64_t z) {
    auto re = metal::precise::sinh(z.real) * metal::precise::cos(z.imag);
    auto im = metal::precise::cosh(z.real) * metal::precise::sin(z.imag);
    return {re, im};
  }
};
|
||||||
|
|
||||||
|
// Elementwise square (x * x).
struct Square {
  template <typename T>
  T operator()(T value) {
    return value * value;
  }
};
|
||||||
|
|
||||||
|
// Elementwise square root via metal::precise::sqrt.
struct Sqrt {
  template <typename T>
  T operator()(T value) {
    return metal::precise::sqrt(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise reciprocal square root via metal::precise::rsqrt.
struct Rsqrt {
  template <typename T>
  T operator()(T value) {
    return metal::precise::rsqrt(value);
  }
};
|
||||||
|
|
||||||
|
// Elementwise tangent.
struct Tan {
  template <typename T>
  T operator()(T value) {
    return metal::precise::tan(value);
  }

  // tan(a + ib) = (tan a + i tanh b) / (1 - i tan a tanh b).
  // Multiplying numerator and denominator by the conjugate of the
  // denominator gives the real denominator 1 + (tan a tanh b)^2.
  template <>
  complex64_t operator()(complex64_t z) {
    float tan_re = metal::precise::tan(z.real);
    float tanh_im = metal::precise::tanh(z.imag);
    float cross = tan_re * tanh_im;
    float denom = 1. + cross * cross;
    return {
        (tan_re - tanh_im * cross) / denom,
        (tanh_im + tan_re * cross) / denom};
  }
};
|
||||||
|
|
||||||
|
// Elementwise hyperbolic tangent.
struct Tanh {
  template <typename T>
  T operator()(T value) {
    return metal::precise::tanh(value);
  }

  // tanh(a + ib) = (tanh a + i tan b) / (1 + i tanh a tan b).
  // Multiplying numerator and denominator by the conjugate of the
  // denominator gives the real denominator 1 + (tanh a tan b)^2.
  template <>
  complex64_t operator()(complex64_t z) {
    float tanh_re = metal::precise::tanh(z.real);
    float tan_im = metal::precise::tan(z.imag);
    float cross = tanh_re * tan_im;
    float denom = 1. + cross * cross;
    return {
        (tanh_re + tan_im * cross) / denom,
        (tan_im - tanh_re * cross) / denom};
  }
};
|
||||||
|
|
||||||
|
// Applies Op to every element: thread `index` reads in[index] and writes
// out[index], so input and output are addressed with the same linear index.
template <typename T, typename Op>
[[kernel]] void unary_op_v(
    device const T* in,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  Op op;
  out[index] = op(in[index]);
}
|
||||||
|
|
||||||
|
// Applies Op to every element of an input addressed through shape/strides:
// the thread's linear `index` is mapped to an input location by
// elem_to_loc(), while the output is written at `index` directly.
template <typename T, typename Op>
[[kernel]] void unary_op_g(
    device const T* in,
    device T* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]) {
  auto src = elem_to_loc(index, in_shape, in_strides, ndim);
  Op op;
  out[index] = op(in[src]);
}
|
||||||
|
|
||||||
|
// Explicitly instantiates unary_op_v<itype, func> and exposes it to the host
// under the kernel name `kname`.
#define instantiate_unary_v(kname, itype, func)  \
  template [[host_name(kname)]] [[kernel]] void  \
  unary_op_v<itype, func>(                       \
      device const itype* in,                    \
      device itype* out,                         \
      uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Explicitly instantiates unary_op_g<itype, func> and exposes it to the host
// under the kernel name `kname`.
#define instantiate_unary_g(kname, itype, func)  \
  template [[host_name(kname)]] [[kernel]] void  \
  unary_op_g<itype, func>(                       \
      device const itype* in,                    \
      device itype* out,                         \
      device const int* in_shape,                \
      device const size_t* in_strides,           \
      device const int& ndim,                    \
      uint index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
// Instantiates both kernel flavours for one (operation, type) pair. The host
// names are "v" or "g" followed by the stringified operation name and type
// name, e.g. "vabsfloat16" and "gabsfloat16".
#define instantiate_unary_all(opname, tname, itype, func)  \
  instantiate_unary_v("v" #opname #tname, itype, func)     \
  instantiate_unary_g("g" #opname #tname, itype, func)
|
||||||
|
|
||||||
|
// Instantiates an operation for the floating point types: float16 (half),
// float32 (float) and bfloat16 (bfloat16_t).
//
// The last line intentionally carries no trailing backslash: a dangling
// line continuation would splice whatever follows the macro into its body.
#define instantiate_unary_float(name, op)                \
  instantiate_unary_all(name, float16, half, op)         \
  instantiate_unary_all(name, float32, float, op)        \
  instantiate_unary_all(name, bfloat16, bfloat16_t, op)
|
||||||
|
|
||||||
|
// Instantiates an operation for bool, every 8/16/32/64-bit signed and
// unsigned integer type, and (via instantiate_unary_float) the floating
// point types.
#define instantiate_unary_types(opname, func)               \
  instantiate_unary_all(opname, bool_, bool, func)          \
  instantiate_unary_all(opname, uint8, uint8_t, func)       \
  instantiate_unary_all(opname, uint16, uint16_t, func)     \
  instantiate_unary_all(opname, uint32, uint32_t, func)     \
  instantiate_unary_all(opname, uint64, uint64_t, func)     \
  instantiate_unary_all(opname, int8, int8_t, func)         \
  instantiate_unary_all(opname, int16, int16_t, func)       \
  instantiate_unary_all(opname, int32, int32_t, func)       \
  instantiate_unary_all(opname, int64, int64_t, func)       \
  instantiate_unary_float(opname, func)
|
||||||
|
|
||||||
|
// Kernel instantiations. `_types` covers bool, integer and floating point
// element types; `_float` covers the floating point types only.
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
instantiate_unary_float(arcsin, ArcSin)
instantiate_unary_float(arcsinh, ArcSinh)
instantiate_unary_float(arctan, ArcTan)
instantiate_unary_float(arctanh, ArcTanh)
instantiate_unary_float(cos, Cos)
instantiate_unary_float(cosh, Cosh)
instantiate_unary_float(exp, Exp)
instantiate_unary_float(log, Log)
instantiate_unary_float(log2, Log2)
instantiate_unary_float(log10, Log10)
instantiate_unary_float(log1p, Log1p)
instantiate_unary_types(neg, Negative)
instantiate_unary_float(sigmoid, Sigmoid)
instantiate_unary_float(erf, Erf)
instantiate_unary_float(erfinv, ErfInv)
instantiate_unary_types(sign, Sign)
instantiate_unary_float(sin, Sin)
instantiate_unary_float(sinh, Sinh)
instantiate_unary_types(square, Square)
instantiate_unary_float(sqrt, Sqrt)
instantiate_unary_float(rsqrt, Rsqrt)
instantiate_unary_float(tan, Tan)
instantiate_unary_float(tanh, Tanh)

// complex64 is only instantiated for the functors used here.
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
instantiate_unary_all(neg, complex64, complex64_t, Negative)
instantiate_unary_all(sin, complex64, complex64_t, Sin)
instantiate_unary_all(sinh, complex64, complex64_t, Sinh)
instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)

// Logical not is only instantiated for bool.
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
|
369
mlx/backend/metal/reduce.cpp
Normal file
369
mlx/backend/metal/reduce.cpp
Normal file
@ -0,0 +1,369 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/reduce.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
// Case wise reduce dispatch
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// All Reduce
|
||||||
|
void all_reduce_dispatch(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
const std::string& op_name,
|
||||||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||||||
|
metal::Device& d) {
|
||||||
|
// Get kernel and encode buffers
|
||||||
|
size_t in_size = in.size();
|
||||||
|
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||||
|
|
||||||
|
// Set grid dimensions
|
||||||
|
|
||||||
|
// We make sure each thread has enough to do by making it read in
|
||||||
|
// atleast n_reads inputs
|
||||||
|
int n_reads = REDUCE_N_READS;
|
||||||
|
|
||||||
|
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||||
|
// input
|
||||||
|
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
|
||||||
|
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
thread_group_size =
|
||||||
|
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
|
||||||
|
|
||||||
|
// If the number of thread groups needed exceeds 1024, we reuse threads groups
|
||||||
|
uint n_thread_groups =
|
||||||
|
(mod_in_size + thread_group_size - 1) / thread_group_size;
|
||||||
|
n_thread_groups = std::min(n_thread_groups, 1024u);
|
||||||
|
uint nthreads = n_thread_groups * thread_group_size;
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void row_reduce_dispatch(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
const std::string& op_name,
|
||||||
|
const std::vector<int>& axes_,
|
||||||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||||||
|
metal::Device& d) {
|
||||||
|
auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in));
|
||||||
|
|
||||||
|
int n_reads = REDUCE_N_READS;
|
||||||
|
size_t reduction_size = in.size() / out.size();
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||||
|
|
||||||
|
// Each thread group is responsible for 1 output
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
thread_group_size =
|
||||||
|
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
|
||||||
|
|
||||||
|
// Align thread group size with simd_size
|
||||||
|
uint simd_size = kernel->threadExecutionWidth();
|
||||||
|
thread_group_size =
|
||||||
|
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||||
|
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||||
|
|
||||||
|
// Launch enough thread groups for each output
|
||||||
|
size_t n_threads = out.size() * thread_group_size;
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_threads, 1, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void col_reduce_dispatch(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
const std::string& op_name,
|
||||||
|
const std::vector<int>& axes_,
|
||||||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||||||
|
metal::Device& d) {
|
||||||
|
std::ostringstream kernel_name;
|
||||||
|
|
||||||
|
bool encode_in_shape = false;
|
||||||
|
bool encode_ndim = false;
|
||||||
|
|
||||||
|
// If the slowest moving axis can be merged into the reductions,
|
||||||
|
// we call the column reduce kernel
|
||||||
|
// In this case, a linear index in the output corresponds to the
|
||||||
|
// linear index in the input where the reduction starts
|
||||||
|
if (axes_[axes_.size() - 1] == (axes_.size() - 1)) {
|
||||||
|
kernel_name << "col_reduce_" << op_name << type_to_name(in);
|
||||||
|
}
|
||||||
|
// Otherwise, while all the reduction axes can be merged, the mapping between
|
||||||
|
// indices in the output and input require resolving using shapes and strides
|
||||||
|
else {
|
||||||
|
kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in);
|
||||||
|
encode_in_shape = true;
|
||||||
|
|
||||||
|
// We check for a viable template with the required number of dimensions
|
||||||
|
// we only care about encoding non-reduced shapes and strides in the input
|
||||||
|
size_t non_reducing_dims = in.ndim() - axes_.size();
|
||||||
|
if (non_reducing_dims >= 1 &&
|
||||||
|
non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) {
|
||||||
|
kernel_name << "_dim_" << non_reducing_dims;
|
||||||
|
} else {
|
||||||
|
encode_ndim = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kernel = d.get_kernel(kernel_name.str());
|
||||||
|
size_t in_size = in.size();
|
||||||
|
size_t out_size = out.size();
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
|
||||||
|
// Calculate the number of inputs to reduce and the stride b/w them
|
||||||
|
size_t reduction_size = 1;
|
||||||
|
size_t in_ndim = in.ndim();
|
||||||
|
size_t reduction_stride = in_size;
|
||||||
|
|
||||||
|
for (int i : axes_) {
|
||||||
|
reduction_size *= in.shape(i);
|
||||||
|
reduction_stride = std::min(reduction_stride, in.strides()[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||||
|
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||||
|
if (encode_in_shape) {
|
||||||
|
// Obtain the non-reducing shape and strides of the input to encode
|
||||||
|
std::vector<int> inp_shape_mod;
|
||||||
|
std::vector<size_t> inp_strides_mod;
|
||||||
|
|
||||||
|
for (size_t i = 0, j = 0; i < in.ndim(); i++) {
|
||||||
|
if (j < axes_.size() && axes_[j] == i) {
|
||||||
|
j++;
|
||||||
|
} else {
|
||||||
|
inp_shape_mod.push_back(in.shape(i));
|
||||||
|
inp_strides_mod.push_back(in.strides()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ndim = inp_shape_mod.size();
|
||||||
|
|
||||||
|
compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5);
|
||||||
|
compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6);
|
||||||
|
|
||||||
|
if (encode_ndim) {
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(size_t), 7);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select block dimensions
|
||||||
|
|
||||||
|
// Each thread reads 16 inputs to give it more work
|
||||||
|
uint n_inputs_per_thread = REDUCE_N_READS;
|
||||||
|
uint n_threads_per_output =
|
||||||
|
(reduction_size + n_inputs_per_thread - 1) / n_inputs_per_thread;
|
||||||
|
|
||||||
|
// We spread outputs over the x dimension and inputs over the y dimension
|
||||||
|
// Threads with the same lid.x in a given threadgroup work on the same
|
||||||
|
// output and each thread in the y dimension accumlates for that output
|
||||||
|
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
||||||
|
uint threadgroup_dim_y =
|
||||||
|
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
||||||
|
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
|
||||||
|
|
||||||
|
uint n_threadgroups_x =
|
||||||
|
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
|
||||||
|
|
||||||
|
uint n_threadgroups_y =
|
||||||
|
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
||||||
|
|
||||||
|
// Launch enough thread groups for each output
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1);
|
||||||
|
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
|
||||||
|
|
||||||
|
// We set shared memory to be exploited here for reductions within a
|
||||||
|
// threadgroup - each thread must be able to update its accumulated output
|
||||||
|
// Note: Each threadgroup should have 32kB of data in threadgroup memory
|
||||||
|
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
|
||||||
|
// This should be fine for floats, but we might need to revisit
|
||||||
|
// if we ever come to doubles. In that case, we should also cut
|
||||||
|
// down the number of threads we launch in a threadgroup
|
||||||
|
compute_encoder->setThreadgroupMemoryLength(
|
||||||
|
threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);
|
||||||
|
|
||||||
|
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void general_reduce_dispatch(
|
||||||
|
const array& in,
|
||||||
|
array& out,
|
||||||
|
const std::string& op_name,
|
||||||
|
const std::vector<int>& axes_,
|
||||||
|
MTL::ComputeCommandEncoder* compute_encoder,
|
||||||
|
metal::Device& d) {
|
||||||
|
bool encode_ndim = true;
|
||||||
|
std::ostringstream kernel_name;
|
||||||
|
kernel_name << "general_reduce_" << op_name << type_to_name(in);
|
||||||
|
|
||||||
|
// Check for specialzed kernels for input ndim
|
||||||
|
if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) {
|
||||||
|
kernel_name << "_dim_" << in.ndim();
|
||||||
|
encode_ndim = false;
|
||||||
|
}
|
||||||
|
auto kernel = d.get_kernel(kernel_name.str());
|
||||||
|
size_t in_size = in.size();
|
||||||
|
size_t ndim = in.ndim();
|
||||||
|
|
||||||
|
// We set the reducing strides to 0 to induce collisions for the reduction
|
||||||
|
std::vector<size_t> out_strides(ndim);
|
||||||
|
size_t stride = 1;
|
||||||
|
for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) {
|
||||||
|
if (j >= 0 && axes_[j] == i) {
|
||||||
|
out_strides[i] = 0;
|
||||||
|
--j;
|
||||||
|
} else {
|
||||||
|
out_strides[i] = stride;
|
||||||
|
stride *= in.shape(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, in, 0);
|
||||||
|
set_array_buffer(compute_encoder, out, 1);
|
||||||
|
compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2);
|
||||||
|
compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3);
|
||||||
|
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
|
||||||
|
if (encode_ndim) {
|
||||||
|
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > in_size) {
|
||||||
|
thread_group_size = in_size;
|
||||||
|
}
|
||||||
|
size_t nthreads = in_size;
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
// Main reduce dispatch
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
// TODO: Allow specific row and column reductions with types disabled
|
||||||
|
// due to atomics ?
|
||||||
|
if (size_of(in.dtype()) == 8) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure no identity reductions trickle down here
|
||||||
|
assert(!axes_.empty());
|
||||||
|
|
||||||
|
// Continue with reduction operation
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
std::string op_name;
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Reduce::And:
|
||||||
|
op_name = "and";
|
||||||
|
break;
|
||||||
|
case Reduce::Or:
|
||||||
|
op_name = "or";
|
||||||
|
break;
|
||||||
|
case Reduce::Sum:
|
||||||
|
op_name = "sum";
|
||||||
|
break;
|
||||||
|
case Reduce::Prod:
|
||||||
|
op_name = out.dtype() == bool_ ? "and" : "prod";
|
||||||
|
break;
|
||||||
|
case Reduce::Min:
|
||||||
|
op_name = out.dtype() == bool_ ? "and" : "min_";
|
||||||
|
break;
|
||||||
|
case Reduce::Max:
|
||||||
|
op_name = out.dtype() == bool_ ? "or" : "max_";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize output
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
{
|
||||||
|
auto kernel = d.get_kernel("i" + op_name + type_to_name(out));
|
||||||
|
size_t nthreads = out.size();
|
||||||
|
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||||
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
|
if (thread_group_size > nthreads) {
|
||||||
|
thread_group_size = nthreads;
|
||||||
|
}
|
||||||
|
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||||
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
set_array_buffer(compute_encoder, out, 0);
|
||||||
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reduce
|
||||||
|
{
|
||||||
|
// Check for contiguous data
|
||||||
|
if (in.size() == in.data_size() &&
|
||||||
|
(in.flags().row_contiguous || in.flags().col_contiguous)) {
|
||||||
|
// Go to all reduce if reducing over all axes
|
||||||
|
if (axes_.size() == in.ndim()) {
|
||||||
|
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Use specialized kernels if the input is row contiguous and
|
||||||
|
// the reducing axes can be merged into one
|
||||||
|
else if (
|
||||||
|
in.flags().row_contiguous && in.strides().back() == 1 &&
|
||||||
|
(axes_.back() - axes_.front()) == axes_.size() - 1) {
|
||||||
|
// If the fastest moving axis is being reduced, go to row reduce
|
||||||
|
if (axes_[0] == (in.ndim() - axes_.size())) {
|
||||||
|
row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Otherwise go to to generalized strided reduce
|
||||||
|
// Note: bool isn't support here yet due to the use of atomics
|
||||||
|
// once that is updated, this should be the else condition of this
|
||||||
|
// branch
|
||||||
|
else if (in.dtype() != bool_) {
|
||||||
|
col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fall back to the general case
|
||||||
|
general_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
15
mlx/backend/no_metal/allocator.cpp
Normal file
15
mlx/backend/no_metal/allocator.cpp
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
|
||||||
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
|
// Returns the allocator used when Metal is not available: a single
// function-local static CommonAllocator, constructed on first use.
Allocator& allocator() {
  static CommonAllocator common_instance;
  return common_instance;
}
|
||||||
|
|
||||||
|
void* Buffer::raw_ptr() {
|
||||||
|
return ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::allocator
|
149
mlx/fft.h
Normal file
149
mlx/fft.h
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include "array.h"
|
||||||
|
#include "device.h"
|
||||||
|
#include "stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core::fft {
|
||||||
|
|
||||||
|
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
||||||
|
|
||||||
|
/** Compute the n-dimensional Fourier Transform. */
|
||||||
|
array fftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
||||||
|
array fftn(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Compute the n-dimensional inverse Fourier Transform. */
|
||||||
|
array ifftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array ifftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array ifftn(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Compute the one-dimensional Fourier Transform. */
|
||||||
|
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
||||||
|
return fftn(a, {n}, {axis}, s);
|
||||||
|
}
|
||||||
|
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
||||||
|
return fftn(a, {axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Compute the one-dimensional inverse Fourier Transform. */
|
||||||
|
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
||||||
|
return ifftn(a, {n}, {axis}, s);
|
||||||
|
}
|
||||||
|
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
||||||
|
return ifftn(a, {axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Compute the two-dimensional Fourier Transform. */
|
||||||
|
inline array fft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return fftn(a, n, axes, s);
|
||||||
|
}
|
||||||
|
inline array fft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes = {-2, -1},
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return fftn(a, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Compute the two-dimensional inverse Fourier Transform. */
|
||||||
|
inline array ifft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return ifftn(a, n, axes, s);
|
||||||
|
}
|
||||||
|
inline array ifft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes = {-2, -1},
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return ifftn(a, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Compute the n-dimensional Fourier Transform on a real input. */
|
||||||
|
array rfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array rfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array rfftn(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Compute the n-dimensional inverse of `rfftn`. */
|
||||||
|
array irfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array irfftn(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array irfftn(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Compute the one-dimensional Fourier Transform on a real input. */
|
||||||
|
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
||||||
|
return rfftn(a, {n}, {axis}, s);
|
||||||
|
}
|
||||||
|
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
||||||
|
return rfftn(a, {axis}, s);
|
||||||
|
}
|
||||||
|
/** Compute the one-dimensional inverse of `rfft`. */
|
||||||
|
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
||||||
|
return irfftn(a, {n}, {axis}, s);
|
||||||
|
}
|
||||||
|
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
||||||
|
return irfftn(a, {axis}, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Compute the two-dimensional Fourier Transform on a real input. */
|
||||||
|
inline array rfft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return rfftn(a, n, axes, s);
|
||||||
|
}
|
||||||
|
inline array rfft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes = {-2, -1},
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return rfftn(a, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Compute the two-dimensional inverse of `rfft2`. */
|
||||||
|
inline array irfft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& n,
|
||||||
|
const std::vector<int>& axes,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return irfftn(a, n, axes, s);
|
||||||
|
}
|
||||||
|
inline array irfft2(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axes = {-2, -1},
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return irfftn(a, axes, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::fft
|
240
mlx/load.cpp
Normal file
240
mlx/load.cpp
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <limits>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "mlx/load.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
// Adapted from
|
||||||
|
// https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {

// Magic prefix that opens every .npy file: byte 0x93 followed by "NUMPY".
static constexpr uint8_t MAGIC[] = {
    0x93,
    0x4e,
    0x55,
    0x4d,
    0x50,
    0x59,
};

// Returns true when the host stores the most significant byte of a
// multi-byte integer at the lowest address.
inline bool is_big_endian_() {
  // Inspect the first byte through memcpy: reading an inactive union member
  // (the previous approach) is undefined behaviour in C++.
  const uint32_t probe = 0x01234567;
  unsigned char first_byte = 0;
  std::memcpy(&first_byte, &probe, 1);
  return first_byte == 0x01;
}

} // namespace
|
||||||
|
|
||||||
|
/** Save array to out stream in .npy format */
|
||||||
|
void save(std::shared_ptr<io::Writer> out_stream, array a, bool retain_graph) {
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Check array
|
||||||
|
|
||||||
|
a.eval(retain_graph);
|
||||||
|
|
||||||
|
if (a.nbytes() == 0) {
|
||||||
|
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!a.flags().contiguous) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[save] cannot serialize a non-contiguous array");
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Check file
|
||||||
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
|
throw std::runtime_error("[save] Failed to open " + out_stream->label());
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Prepare header
|
||||||
|
std::ostringstream magic_ver_len;
|
||||||
|
magic_ver_len.write(reinterpret_cast<const char*>(MAGIC), 6);
|
||||||
|
|
||||||
|
std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
|
||||||
|
std::ostringstream header;
|
||||||
|
header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
|
||||||
|
<< " 'fortran_order': " << fortran_order << ","
|
||||||
|
<< " 'shape': (";
|
||||||
|
for (auto i : a.shape()) {
|
||||||
|
header << i << ", ";
|
||||||
|
}
|
||||||
|
header << ")}";
|
||||||
|
|
||||||
|
size_t header_len = static_cast<size_t>(header.tellp());
|
||||||
|
bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max();
|
||||||
|
|
||||||
|
// Pad out magic + version + header_len + header + \n to be divisible by 16
|
||||||
|
size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16;
|
||||||
|
|
||||||
|
header << std::string(padding, ' ') << '\n';
|
||||||
|
|
||||||
|
if (is_v1) {
|
||||||
|
magic_ver_len << (char)0x01 << (char)0x00;
|
||||||
|
|
||||||
|
uint16_t v1_header_len = header.tellp();
|
||||||
|
const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
|
||||||
|
|
||||||
|
if (!is_big_endian_()) {
|
||||||
|
magic_ver_len.write(len_bytes, 2);
|
||||||
|
} else {
|
||||||
|
magic_ver_len.write(len_bytes + 1, 1);
|
||||||
|
magic_ver_len.write(len_bytes, 1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
magic_ver_len << (char)0x02 << (char)0x00;
|
||||||
|
|
||||||
|
uint32_t v2_header_len = header.tellp();
|
||||||
|
const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
|
||||||
|
|
||||||
|
if (!is_big_endian_()) {
|
||||||
|
magic_ver_len.write(len_bytes, 4);
|
||||||
|
} else {
|
||||||
|
magic_ver_len.write(len_bytes + 3, 1);
|
||||||
|
magic_ver_len.write(len_bytes + 2, 1);
|
||||||
|
magic_ver_len.write(len_bytes + 1, 1);
|
||||||
|
magic_ver_len.write(len_bytes, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Serialize array
|
||||||
|
|
||||||
|
out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
|
||||||
|
out_stream->write(header.str().c_str(), header.str().length());
|
||||||
|
out_stream->write(a.data<char>(), a.nbytes());
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Save array to file in .npy format */
|
||||||
|
void save(const std::string& file_, array a, bool retain_graph) {
|
||||||
|
// Open and check file
|
||||||
|
std::string file = file_;
|
||||||
|
|
||||||
|
// Add .npy to file name if it is not there
|
||||||
|
if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
|
||||||
|
file += ".npy";
|
||||||
|
|
||||||
|
// Serialize array
|
||||||
|
save(std::make_shared<io::FileWriter>(file), a, retain_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Load array from reader in .npy format */
|
||||||
|
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Open and check file
|
||||||
|
if (!in_stream->good() || !in_stream->is_open()) {
|
||||||
|
throw std::runtime_error("[load] Failed to open " + in_stream->label());
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Read header and prepare array details
|
||||||
|
|
||||||
|
// Read and check magic
|
||||||
|
char read_magic_and_ver[8];
|
||||||
|
in_stream->read(read_magic_and_ver, 8);
|
||||||
|
if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {
|
||||||
|
throw std::runtime_error("[load] Invalid header in " + in_stream->label());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and check version
|
||||||
|
if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[load] Unsupport npy format version in " + in_stream->label());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read header len and header
|
||||||
|
int header_len_size = read_magic_and_ver[6] == 1 ? 2 : 4;
|
||||||
|
size_t header_len;
|
||||||
|
|
||||||
|
if (header_len_size == 2) {
|
||||||
|
uint16_t v1_header_len;
|
||||||
|
in_stream->read(reinterpret_cast<char*>(&v1_header_len), header_len_size);
|
||||||
|
header_len = v1_header_len;
|
||||||
|
} else {
|
||||||
|
uint32_t v2_header_len;
|
||||||
|
in_stream->read(reinterpret_cast<char*>(&v2_header_len), header_len_size);
|
||||||
|
header_len = v2_header_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the header
|
||||||
|
std::vector<char> buffer(header_len + 1);
|
||||||
|
in_stream->read(&buffer[0], header_len);
|
||||||
|
buffer[header_len] = 0;
|
||||||
|
std::string header(&buffer[0]);
|
||||||
|
|
||||||
|
// Read data type from header
|
||||||
|
std::string dtype_str = header.substr(11, 3);
|
||||||
|
bool read_is_big_endian = dtype_str[0] == '>';
|
||||||
|
Dtype dtype = dtype_from_array_protocol(dtype_str);
|
||||||
|
|
||||||
|
// Read contiguity order
|
||||||
|
bool col_contiguous = header[34] == 'T';
|
||||||
|
|
||||||
|
// Read array shape from header
|
||||||
|
std::vector<int> shape;
|
||||||
|
|
||||||
|
size_t st = header.find_last_of('(') + 1;
|
||||||
|
size_t ed = header.find_last_of(')');
|
||||||
|
std::string shape_str = header.substr(st, ed - st);
|
||||||
|
|
||||||
|
while (!shape_str.empty()) {
|
||||||
|
// Read current number and get position of comma
|
||||||
|
size_t pos;
|
||||||
|
int dim = std::stoi(shape_str, &pos);
|
||||||
|
shape.push_back(dim);
|
||||||
|
|
||||||
|
// Skip the comma and space and read the next number
|
||||||
|
if (pos + 2 <= shape_str.length())
|
||||||
|
shape_str = shape_str.substr(pos + 2);
|
||||||
|
else {
|
||||||
|
shape_str = shape_str.substr(pos);
|
||||||
|
if (!shape_str.empty() && shape_str != " " && shape_str != ",") {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[load] Unknown error while parsing header in " +
|
||||||
|
in_stream->label());
|
||||||
|
}
|
||||||
|
shape_str = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
// Build primitive
|
||||||
|
|
||||||
|
size_t offset = 8 + header_len_size + header.length();
|
||||||
|
bool swap_endianness = read_is_big_endian != is_big_endian_();
|
||||||
|
|
||||||
|
if (col_contiguous) {
|
||||||
|
std::reverse(shape.begin(), shape.end());
|
||||||
|
}
|
||||||
|
auto loaded_array = array(
|
||||||
|
shape,
|
||||||
|
dtype,
|
||||||
|
std::make_unique<Load>(to_stream(s), in_stream, offset, swap_endianness),
|
||||||
|
std::vector<array>{});
|
||||||
|
if (col_contiguous) {
|
||||||
|
loaded_array = transpose(loaded_array, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return loaded_array;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Load array from file in .npy format */
|
||||||
|
array load(const std::string& file, StreamOrDevice s) {
|
||||||
|
return load(std::make_shared<io::FileReader>(file), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
2265
mlx/primitives.cpp
Normal file
2265
mlx/primitives.cpp
Normal file
File diff suppressed because it is too large
Load Diff
778
mlx/transforms.cpp
Normal file
778
mlx/transforms.cpp
Normal file
@ -0,0 +1,778 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <future>
|
||||||
|
#include <map>
|
||||||
|
#include <numeric>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
#include "mlx/transforms.h"
|
||||||
|
#include "mlx/transforms_impl.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void simplify(const std::vector<array>& outputs) {
|
||||||
|
std::function<void(const array&)> recurse;
|
||||||
|
std::queue<array> tape;
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||||
|
parents_map;
|
||||||
|
|
||||||
|
// Helpers to identify identical scalars
|
||||||
|
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||||
|
auto is_scalar = [](const array& a) {
|
||||||
|
return a.is_evaled() && a.ndim() == 0;
|
||||||
|
};
|
||||||
|
auto get_scalar_rep = [](const array& a) {
|
||||||
|
uint64_t v = 0;
|
||||||
|
int dtype;
|
||||||
|
switch (a.dtype().size) {
|
||||||
|
case 1:
|
||||||
|
v = *a.data<uint8_t>();
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
v = *a.data<uint32_t>();
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
v = *a.data<uint64_t>();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return std::make_pair(v, a.dtype().val);
|
||||||
|
};
|
||||||
|
|
||||||
|
// DFS the graph to log the parents
|
||||||
|
recurse = [&](const array& a) {
|
||||||
|
auto id = a.id();
|
||||||
|
if (cache.find(id) != cache.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < a.inputs().size(); i++) {
|
||||||
|
auto& in = a.inputs()[i];
|
||||||
|
parents_map[in.id()].push_back({a, i});
|
||||||
|
recurse(in);
|
||||||
|
}
|
||||||
|
cache.insert(id);
|
||||||
|
tape.push(a);
|
||||||
|
if (is_scalar(a)) {
|
||||||
|
scalars.insert({get_scalar_rep(a), a});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (auto& a : outputs) {
|
||||||
|
recurse(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper that fuses two arrays in the graph by setting the parents of the
|
||||||
|
// source to point to the destination
|
||||||
|
auto fuse = [&](array& dst, array& src) {
|
||||||
|
auto src_parents = parents_map.find(src.id());
|
||||||
|
if (src_parents == parents_map.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& pairs = parents_map[dst.id()];
|
||||||
|
for (auto& parent : src_parents->second) {
|
||||||
|
parent.first.editable_inputs()[parent.second] = dst;
|
||||||
|
pairs.push_back(parent);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Walk the graph
|
||||||
|
cache.clear();
|
||||||
|
|
||||||
|
// Depth-1 array equivalence check.
|
||||||
|
auto array_equivalent = [](const array& a, const array& b) {
|
||||||
|
if (!a.has_primitive() || !b.has_primitive()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto& pa = a.primitive();
|
||||||
|
const auto& pb = b.primitive();
|
||||||
|
if (typeid(pa) != typeid(pb)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.inputs().size() != b.inputs().size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < a.inputs().size(); i++) {
|
||||||
|
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pa.is_equivalent(pb);
|
||||||
|
};
|
||||||
|
|
||||||
|
while (!tape.empty()) {
|
||||||
|
auto arr = std::move(tape.front());
|
||||||
|
tape.pop();
|
||||||
|
|
||||||
|
if (cache.find(arr.id()) != cache.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we can fuse scalars
|
||||||
|
if (is_scalar(arr)) {
|
||||||
|
auto scalar = scalars.find(get_scalar_rep(arr));
|
||||||
|
if (scalar->second.id() != arr.id()) {
|
||||||
|
fuse(scalar->second, arr);
|
||||||
|
arr = scalar->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we can fuse the parents of this array
|
||||||
|
auto parents = parents_map.find(arr.id());
|
||||||
|
if (parents != parents_map.end()) {
|
||||||
|
std::vector<bool> mask(parents->second.size(), false);
|
||||||
|
auto N = parents->second.size();
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
if (mask[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (int j = i + 1; j < N; j++) {
|
||||||
|
if (mask[j]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto& src = parents->second[j].first;
|
||||||
|
auto& dst = parents->second[i].first;
|
||||||
|
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||||
|
cache.insert(src.id());
|
||||||
|
fuse(dst, src);
|
||||||
|
mask[j] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void eval(const std::vector<array>& outputs, bool retain_graph /* = false */) {
|
||||||
|
if (!retain_graph) {
|
||||||
|
for (auto& out : outputs) {
|
||||||
|
if (out.has_primitive() && out.is_tracer()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[eval] Illegal to eval an array during "
|
||||||
|
"function transform without graph retention.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::function<void(const array&)> recurse;
|
||||||
|
std::queue<array> tape;
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
|
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
|
||||||
|
|
||||||
|
recurse = [&](const array& a) {
|
||||||
|
auto id = a.id();
|
||||||
|
if (cache.find(id) != cache.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (auto in : a.inputs()) {
|
||||||
|
recurse(in);
|
||||||
|
// If one of the inputs is being computed on a different
|
||||||
|
// stream, we need to manage the dependency.
|
||||||
|
if (!in.is_evaled()) {
|
||||||
|
if (a.primitive().stream() != in.primitive().stream()) {
|
||||||
|
deps.insert({in.id(), std::shared_future<void>{}});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache.insert(id);
|
||||||
|
if (!a.is_evaled() || (!retain_graph && a.has_primitive())) {
|
||||||
|
if (!a.has_primitive()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[eval] Attempting to eval an array without a primitive.");
|
||||||
|
}
|
||||||
|
tape.push(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& arr : outputs) {
|
||||||
|
if (!arr.is_evaled() || (!retain_graph && arr.has_primitive())) {
|
||||||
|
recurse(arr);
|
||||||
|
// Insert a dependency for every output to synchronize
|
||||||
|
// with at the end.
|
||||||
|
if (!arr.is_evaled()) {
|
||||||
|
deps.insert({arr.id(), std::shared_future<void>{}});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!tape.empty()) {
|
||||||
|
auto arr = std::move(tape.front());
|
||||||
|
tape.pop();
|
||||||
|
if (arr.is_evaled()) {
|
||||||
|
if (!retain_graph && arr.has_primitive()) {
|
||||||
|
arr.detach();
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto stream = arr.primitive().stream();
|
||||||
|
std::vector<std::shared_future<void>> arr_deps;
|
||||||
|
for (auto& in : arr.inputs()) {
|
||||||
|
if (auto it = deps.find(in.id()); it != deps.end()) {
|
||||||
|
arr_deps.push_back(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::shared_ptr<std::promise<void>> p{nullptr};
|
||||||
|
if (auto it = deps.find(arr.id()); it != deps.end()) {
|
||||||
|
p = std::make_unique<std::promise<void>>();
|
||||||
|
it->second = p->get_future().share();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (arr.primitive().device() == Device::gpu) {
|
||||||
|
if (!metal::is_available()) {
|
||||||
|
throw std::runtime_error("Metal GPU is not available.");
|
||||||
|
}
|
||||||
|
scheduler::enqueue(
|
||||||
|
stream,
|
||||||
|
metal::make_task(
|
||||||
|
arr, std::move(arr_deps), std::move(p), retain_graph));
|
||||||
|
} else {
|
||||||
|
auto task = [retain_graph,
|
||||||
|
arr,
|
||||||
|
stream,
|
||||||
|
arr_deps = std::move(arr_deps),
|
||||||
|
p = std::move(p)]() mutable {
|
||||||
|
for (auto& d : arr_deps) {
|
||||||
|
d.wait();
|
||||||
|
}
|
||||||
|
scheduler::notify_new_task(stream);
|
||||||
|
arr.primitive().eval_cpu(arr.inputs(), arr);
|
||||||
|
if (!retain_graph) {
|
||||||
|
arr.detach();
|
||||||
|
}
|
||||||
|
if (p) {
|
||||||
|
p->set_value();
|
||||||
|
}
|
||||||
|
scheduler::notify_task_completion(stream);
|
||||||
|
};
|
||||||
|
scheduler::enqueue(stream, std::move(task));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& arr : outputs) {
|
||||||
|
if (auto it = deps.find(arr.id()); it != deps.end()) {
|
||||||
|
it->second.wait();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotans) {
|
||||||
|
// Make tracers from given primals
|
||||||
|
std::vector<array> primals_;
|
||||||
|
for (auto& p : primals) {
|
||||||
|
auto s = p.has_primitive() ? p.primitive().stream()
|
||||||
|
: default_stream(default_device());
|
||||||
|
primals_.push_back(copy(p, s)); // Does not do a deep copy
|
||||||
|
primals_.back().set_tracer(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass tracer primals through the function
|
||||||
|
// Any variables that depend on the primals are marked as tracers
|
||||||
|
auto outputs = fun(primals_);
|
||||||
|
|
||||||
|
// Map outputs to passed cotans while ignoring the outputs
|
||||||
|
// that have stop_gradient called on them
|
||||||
|
int cotan_index = 0;
|
||||||
|
std::vector<std::pair<int, int>> output_cotan_pairs;
|
||||||
|
for (int i = 0; i < outputs.size(); ++i) {
|
||||||
|
auto& out = outputs[i];
|
||||||
|
if (out.has_primitive()) {
|
||||||
|
if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cotan_index >= cotans.size()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vjp] Number of outputs with gradient does not match number of cotangents.");
|
||||||
|
}
|
||||||
|
if (out.shape() != cotans[cotan_index].shape()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vjp] Output shape does not match shape of cotangent.");
|
||||||
|
}
|
||||||
|
output_cotan_pairs.emplace_back(i, cotan_index++);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topologically sort the compute graph, record outputs
|
||||||
|
// in the tape if a gradient is needed.
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
|
std::unordered_set<std::uintptr_t> calc_grad;
|
||||||
|
for (auto& primal : primals_) {
|
||||||
|
primal.set_tracer(false);
|
||||||
|
calc_grad.insert(primal.id());
|
||||||
|
cache.insert(primal.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> tape;
|
||||||
|
|
||||||
|
std::function<void(array&)> recurse;
|
||||||
|
recurse = [&](auto& a) {
|
||||||
|
auto id = a.id();
|
||||||
|
a.set_tracer(false);
|
||||||
|
|
||||||
|
// Check if visited and add to cache if not
|
||||||
|
if (auto inserted = cache.insert(id); !inserted.second) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& input : a.editable_inputs()) {
|
||||||
|
recurse(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop grad
|
||||||
|
if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate gradient if any inputs require gradient
|
||||||
|
for (auto& input : a.inputs()) {
|
||||||
|
if (calc_grad.find(input.id()) != calc_grad.end()) {
|
||||||
|
tape.push_back(a);
|
||||||
|
calc_grad.insert(id);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& out : outputs) {
|
||||||
|
recurse(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the tape backwards, computing vector-jacobian
|
||||||
|
// products for each primitive
|
||||||
|
std::unordered_map<std::uintptr_t, array> cotan_map;
|
||||||
|
for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
|
||||||
|
cotan_map.insert({outputs[out_idx].id(), cotans[cotan_idx]});
|
||||||
|
}
|
||||||
|
for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
|
||||||
|
auto& a = *it;
|
||||||
|
|
||||||
|
// Get the arguments whose gradients are needed
|
||||||
|
std::vector<int> argnums;
|
||||||
|
for (int i = 0; i < a.inputs().size(); ++i) {
|
||||||
|
if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
|
||||||
|
argnums.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cotan_it = cotan_map.find(a.id());
|
||||||
|
if (cotan_it == cotan_map.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cotan = cotan_map.extract(cotan_it).mapped();
|
||||||
|
auto vjps = a.primitive().vjp(a.inputs(), cotan, argnums);
|
||||||
|
auto s = a.primitive().stream();
|
||||||
|
// Accumulate the vector-jacobian products for each input
|
||||||
|
for (int i = 0; i < argnums.size(); ++i) {
|
||||||
|
auto in_id = a.inputs()[argnums[i]].id();
|
||||||
|
if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
|
||||||
|
cotan_it->second = add(cotan_it->second, vjps[i], s);
|
||||||
|
} else {
|
||||||
|
cotan_map.insert({in_id, vjps[i]});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> vjps;
|
||||||
|
for (auto& primal : primals_) {
|
||||||
|
if (auto cotan_it = cotan_map.find(primal.id());
|
||||||
|
cotan_it != cotan_map.end()) {
|
||||||
|
vjps.push_back(cotan_it->second);
|
||||||
|
} else {
|
||||||
|
auto s = primal.has_primitive() ? primal.primitive().stream()
|
||||||
|
: default_stream(default_device());
|
||||||
|
vjps.push_back(zeros_like(primal, s));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {outputs, vjps};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, array> vjp(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
const array& primal,
|
||||||
|
const array& cotan) {
|
||||||
|
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{fun(inputs[0])};
|
||||||
|
};
|
||||||
|
auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});
|
||||||
|
return {outputs[0], vjps[0]};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> jvp(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents) {
|
||||||
|
if (primals.size() != tangents.size()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[jvp] Number of inputs does not match number of tangents.");
|
||||||
|
}
|
||||||
|
for (int i = 0; i < primals.size(); ++i) {
|
||||||
|
if (primals[i].shape() != tangents[i].shape()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[jvp] Input shape does not match shape of tangent.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> primals_;
|
||||||
|
for (auto& p : primals) {
|
||||||
|
auto s = p.has_primitive() ? p.primitive().stream()
|
||||||
|
: default_stream(default_device());
|
||||||
|
primals_.push_back(copy(p, s)); // Does not do a deep copy
|
||||||
|
primals_.back().set_tracer(true);
|
||||||
|
}
|
||||||
|
auto outputs = fun(primals_);
|
||||||
|
|
||||||
|
// Topologically sort the compute graph, record outputs
|
||||||
|
// in the tape if a gradient is needed.
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
|
std::unordered_set<std::uintptr_t> calc_grad;
|
||||||
|
for (auto& primal : primals_) {
|
||||||
|
primal.set_tracer(false);
|
||||||
|
calc_grad.insert(primal.id());
|
||||||
|
cache.insert(primal.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> tape;
|
||||||
|
|
||||||
|
std::function<void(array&)> recurse;
|
||||||
|
recurse = [&](auto& a) {
|
||||||
|
auto id = a.id();
|
||||||
|
a.set_tracer(false);
|
||||||
|
|
||||||
|
// Check if visited and add to cache if not
|
||||||
|
if (auto inserted = cache.insert(id); !inserted.second) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& input : a.editable_inputs()) {
|
||||||
|
recurse(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop grad
|
||||||
|
if (a.has_primitive() && typeid(a.primitive()) == typeid(StopGradient)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate gradient if any inputs require gradient
|
||||||
|
for (auto& input : a.inputs()) {
|
||||||
|
if (calc_grad.find(input.id()) != calc_grad.end()) {
|
||||||
|
tape.push_back(a);
|
||||||
|
calc_grad.insert(id);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& out : outputs) {
|
||||||
|
recurse(out);
|
||||||
|
}
|
||||||
|
std::unordered_map<std::uintptr_t, array> tan_map;
|
||||||
|
for (int i = 0; i < primals_.size(); ++i) {
|
||||||
|
tan_map.insert({primals_[i].id(), tangents[i]});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& a : tape) {
|
||||||
|
// Get the arguments used in the jvp
|
||||||
|
std::vector<int> argnums;
|
||||||
|
std::vector<array> tangents;
|
||||||
|
for (int i = 0; i < a.inputs().size(); ++i) {
|
||||||
|
if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
|
||||||
|
argnums.push_back(i);
|
||||||
|
tangents.push_back(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto jvp = a.primitive().jvp(a.inputs(), tangents, argnums);
|
||||||
|
tan_map.insert({a.id(), jvp});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> jvps;
|
||||||
|
for (auto& out : outputs) {
|
||||||
|
if (auto it = tan_map.find(out.id()); it != tan_map.end()) {
|
||||||
|
jvps.push_back(it->second);
|
||||||
|
} else {
|
||||||
|
auto s = out.has_primitive() ? out.primitive().stream()
|
||||||
|
: default_stream(default_device());
|
||||||
|
jvps.push_back(zeros_like(out, s));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {outputs, jvps};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, array> jvp(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
const array& primal,
|
||||||
|
const array& tangent) {
|
||||||
|
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{fun(inputs[0])};
|
||||||
|
};
|
||||||
|
auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});
|
||||||
|
return {outputs[0], jvps[0]};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Build a function that returns both the outputs of `fun` and the gradient
 * of its first output with respect to the arguments selected by `argnums`.
 *
 * Negative argument numbers count from the end of the input list. All
 * outputs after the first are wrapped in `stop_gradient`.
 *
 * Throws std::invalid_argument when `argnums` is empty and, when the
 * returned function is called, when an argument number is repeated or out
 * of range for the given inputs.
 */
ValueAndGradFn value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  if (argnums.empty()) {
    throw std::invalid_argument("[grad] Must specify at least one argument.");
  }
  return [fun, argnums](const std::vector<array>& inputs) {
    // Normalize the argument numbers into a sorted, duplicate-free set
    std::set<int> args;
    for (const int& requested : argnums) {
      args.insert(requested < 0 ? requested + inputs.size() : requested);
    }
    if (args.size() != argnums.size()) {
      throw std::invalid_argument(
          "[grad] Repeat argument number not allowed in grad.");
    }
    if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
      std::ostringstream msg;
      msg << "[grad] Invalid argument number for function with "
          << inputs.size() << " inputs.";
      throw std::invalid_argument(msg.str());
    }

    // Function of the differentiated arguments only: the remaining inputs
    // are taken from the original call.
    auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
      std::vector<array> full_inputs(inputs);
      auto slot = args.begin();
      for (int i = 0; i < ginputs.size(); ++i, ++slot) {
        full_inputs[*slot] = ginputs[i];
      }
      auto outputs = fun(full_inputs);
      // Only the first output contributes to the gradient
      for (int i = 1; i < outputs.size(); i++) {
        auto& extra = outputs[i];
        auto stream = extra.has_primitive() ? extra.primitive().stream()
                                            : default_stream(default_device());
        outputs[i] = stop_gradient(extra, stream);
      }
      return outputs;
    };

    std::vector<array> ginputs;
    for (int position : args) {
      ginputs.push_back(inputs[position]);
    }
    // Set the incoming gradient as int32 so that it will be promoted to the
    // appropriate floating point type op(int, floatXX) -> floatXX for most ops
    return vjp(gfun, ginputs, {array(1)});
  };
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& in_axes) {
|
||||||
|
if (in_axes.size() != inputs.size()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vmap] The number of in axes must match the number of inputs.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the function on placeholder inputs
|
||||||
|
// to get the original graph
|
||||||
|
std::vector<array> s_inputs;
|
||||||
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
|
if (in_axes[i] != -1) {
|
||||||
|
if (inputs[i].ndim() == 0) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vmap] Cannot vmap an input with zero dimensions.");
|
||||||
|
}
|
||||||
|
if (in_axes[i] > inputs[i].ndim()) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[vmap] Axis " << in_axes[i] << " invalid for input with "
|
||||||
|
<< inputs[i].ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> shape = inputs[i].shape();
|
||||||
|
shape.erase(shape.begin() + in_axes[i]);
|
||||||
|
array in(shape, inputs[i].dtype(), nullptr, {});
|
||||||
|
s_inputs.push_back(in);
|
||||||
|
s_inputs.back().set_tracer(true);
|
||||||
|
} else {
|
||||||
|
s_inputs.push_back(inputs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {s_inputs, fun(s_inputs)};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Rebuild the traced graph with each primitive replaced by its vmap
 * implementation.
 *
 * `inputs` are the real inputs, `s_inputs` / `s_outputs` the placeholder
 * inputs and outputs of the traced function, and `in_axes` / `out_axes` the
 * requested mapped axes (an input axis of -1 means "not mapped").
 *
 * Returns one array per entry of `s_outputs`. Outputs that were produced by a
 * vmapped primitive are transposed so the mapped axis sits at `out_axes[i]`;
 * outputs that were never reached by a mapped input are returned as traced.
 *
 * Throws std::invalid_argument if the number of out axes differs from the
 * number of outputs or an out axis is not smaller than the output's ndim.
 */
std::vector<array> vmap_replace(
    const std::vector<array>& inputs,
    const std::vector<array>& s_inputs,
    const std::vector<array>& s_outputs,
    const std::vector<int>& in_axes,
    const std::vector<int>& out_axes) {
  if (out_axes.size() != s_outputs.size()) {
    throw std::invalid_argument(
        "[vmap] The number of out axes must match the number of outputs.");
  }

  // tmap: id of a traced array -> (its vectorized replacement, mapped axis).
  // needs_vmap: ids of traced arrays that depend on at least one mapped input.
  std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
  std::unordered_set<std::uintptr_t> needs_vmap;
  for (int i = 0; i < s_inputs.size(); ++i) {
    if (in_axes[i] != -1) {
      tmap.insert({s_inputs[i].id(), {inputs[i], in_axes[i]}});
      needs_vmap.insert(s_inputs[i].id());
    }
  }

  // Topologically sort the graph
  // Seed the visited set with the placeholder inputs so the traversal below
  // stops at them; mapped placeholders also have their tracer flag cleared.
  std::unordered_set<std::uintptr_t> cache;
  for (int i = 0; i < s_inputs.size(); ++i) {
    auto in = s_inputs[i];
    if (in_axes[i] != -1) {
      in.set_tracer(false);
    }
    cache.insert(in.id());
  }
  std::vector<array> tape;

  std::function<void(const array&)> recurse;

  // Depth-first, post-order walk: an array's inputs are visited before the
  // array itself, so `tape` ends up ordered with producers before consumers.
  recurse = [&](const array& a) {
    // Stop at inputs to the vmap function
    auto id = a.id();
    if (cache.find(id) != cache.end()) {
      return;
    }
    for (auto& input : a.inputs()) {
      recurse(input);
    }
    cache.insert(id);
    // Record `a` only if at least one of its inputs needs vmapping
    for (auto& input : a.inputs()) {
      if (needs_vmap.find(input.id()) != needs_vmap.end()) {
        needs_vmap.insert(id);
        tape.push_back(a);
        tape.back().set_tracer(false);
        break;
      }
    }
  };

  for (auto& out : s_outputs) {
    recurse(out);
  }

  // Transform each primitive in the graph with
  // its vmap implementation
  for (auto& a : tape) {
    std::vector<array> v_inputs;
    std::vector<int> v_axes;
    for (auto& in : a.inputs()) {
      auto map_it = tmap.find(in.id());
      if (map_it != tmap.end()) {
        // Input already has a vectorized replacement
        v_inputs.push_back(map_it->second.first);
        v_axes.push_back(map_it->second.second);
      } else {
        // Input is not mapped: pass it through with axis -1
        v_inputs.push_back(in);
        v_axes.push_back(-1);
      }
    }
    auto out_and_axis = a.primitive().vmap(v_inputs, v_axes);
    tmap.insert({a.id(), out_and_axis});
  }

  // Populate the outputs and make sure all the output axes are
  // in the right place
  std::vector<array> outputs;
  for (int i = 0; i < s_outputs.size(); ++i) {
    auto map_it = tmap.find(s_outputs[i].id());
    if (map_it != tmap.end()) {
      auto& [out, vdim] = map_it->second;
      if (vdim != out_axes[i]) {
        if (out_axes[i] >= out.ndim()) {
          std::ostringstream msg;
          msg << "[vmap] Axis " << out_axes[i] << " invalid for output with "
              << out.ndim() << " dimensions.";
          throw std::invalid_argument(msg.str());
        }
        // Build the permutation that moves axis `vdim` to `out_axes[i]`
        // while keeping the relative order of the remaining axes.
        std::vector<int> reorder(out.ndim());
        std::iota(reorder.begin(), reorder.end(), 0);
        reorder.erase(reorder.begin() + vdim);
        reorder.insert(reorder.begin() + out_axes[i], vdim);
        out = transpose(out, reorder);
      }
      outputs.push_back(out);
    } else {
      // Output does not depend on any mapped input
      outputs.push_back(s_outputs[i]);
    }
  }
  return outputs;
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> vmap(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<int>& in_axes /* = {} */,
|
||||||
|
const std::vector<int>& out_axes /* = {} */) {
|
||||||
|
auto infer_axes = [](auto axes) {
|
||||||
|
return !axes.empty() &&
|
||||||
|
std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });
|
||||||
|
};
|
||||||
|
if (infer_axes(in_axes) != infer_axes(out_axes)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vmap] Input (or output) axes must be "
|
||||||
|
"specified if output (or input) axes are.");
|
||||||
|
}
|
||||||
|
auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](
|
||||||
|
const std::vector<array>& inputs) mutable {
|
||||||
|
if (in_axes.size() == 0) {
|
||||||
|
in_axes.resize(inputs.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto [trace_inputs, trace_outputs] =
|
||||||
|
detail::vmap_trace(fun, inputs, in_axes);
|
||||||
|
|
||||||
|
if (out_axes.size() == 0) {
|
||||||
|
out_axes.resize(trace_outputs.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return detail::vmap_replace(
|
||||||
|
inputs, trace_inputs, trace_outputs, in_axes, out_axes);
|
||||||
|
};
|
||||||
|
|
||||||
|
return vfun;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<array(const array&, const array&)> vmap(
|
||||||
|
const std::function<array(const array&, const array&)>& fun,
|
||||||
|
int in_axis_a /* = 0 */,
|
||||||
|
int in_axis_b /* = 0 */,
|
||||||
|
int out_axis /* = 0 */) {
|
||||||
|
auto vfun = vmap(
|
||||||
|
[in_axis_a, in_axis_b, out_axis, fun](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{fun(inputs[0], inputs[1])};
|
||||||
|
},
|
||||||
|
{in_axis_a, in_axis_b},
|
||||||
|
{out_axis});
|
||||||
|
return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; };
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<array(const array&)> vmap(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
int in_axis /* = 0 */,
|
||||||
|
int out_axis /* = 0 */) {
|
||||||
|
auto vfun = vmap(
|
||||||
|
[in_axis, out_axis, fun](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{fun(inputs[0])};
|
||||||
|
},
|
||||||
|
{in_axis},
|
||||||
|
{out_axis});
|
||||||
|
return [vfun](const array& a) { return vfun({a})[0]; };
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
185
mlx/transforms.h
Normal file
185
mlx/transforms.h
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "array.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/** Fuse equivalent arrays to avoid duplicate execution. */
|
||||||
|
void simplify(const std::vector<array>& outputs);
|
||||||
|
|
||||||
|
template <typename... Arrays>
|
||||||
|
void simplify(Arrays... outputs) {
|
||||||
|
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
|
}
|
||||||
|
|
||||||
|
void eval(const std::vector<array>& outputs, bool retain_graph = false);
|
||||||
|
|
||||||
|
template <typename... Arrays>
|
||||||
|
void eval(Arrays... outputs) {
|
||||||
|
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the output and vector-Jacobian product (VJP) of a function.
|
||||||
|
*
|
||||||
|
* Computes the vector-Jacobian product of the vector of cotangents with the
|
||||||
|
* Jacobian of the function evaluated at the primals. Returns a pair of
|
||||||
|
* vectors of output arrays and VJP arrays.
|
||||||
|
**/
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the output and vector-Jacobian product (VJP) of a unary function.
|
||||||
|
*/
|
||||||
|
std::pair<array, array> vjp(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
const array& primal,
|
||||||
|
const array& cotangent);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the output and Jacobian-vector product (JVP) of a function.
|
||||||
|
*
|
||||||
|
* Computes the Jacobian-vector product of the Jacobian of the function
|
||||||
|
* evaluated at the primals with the vector of tangents. Returns a pair of
|
||||||
|
* vectors of output arrays and JVP arrays.
|
||||||
|
**/
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> jvp(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the output and Jacobian-vector product (JVP) of a unary function.
|
||||||
|
*/
|
||||||
|
std::pair<array, array> jvp(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
const array& primal,
|
||||||
|
const array& tangent);
|
||||||
|
|
||||||
|
// Return type of general value_and_grad: a function which takes an input
|
||||||
|
// vector of arrays and returns a pair of vectors of arrays one for the
|
||||||
|
// values and one for the gradients wrt the first value.
|
||||||
|
using ValueAndGradFn =
|
||||||
|
std::function<std::pair<std::vector<array>, std::vector<array>>(
|
||||||
|
const std::vector<array>&)>;
|
||||||
|
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
|
||||||
|
const std::vector<array>&)>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a function which computes the value and gradient of the input
|
||||||
|
* function with respect to a vector of input arrays.
|
||||||
|
**/
|
||||||
|
ValueAndGradFn value_and_grad(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<int>& argnums);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a function which computes the value and gradient of the input
|
||||||
|
* function with repsect to a single input array.
|
||||||
|
**/
|
||||||
|
ValueAndGradFn inline value_and_grad(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
int argnum = 0) {
|
||||||
|
return value_and_grad(fun, std::vector<int>{argnum});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a function which computes the value and gradient of the unary
|
||||||
|
* input function.
|
||||||
|
**/
|
||||||
|
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
|
||||||
|
const std::function<array(const array&)>& fun) {
|
||||||
|
return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Returns a function which computes the value of an array-valued function
 * together with its gradients with respect to the inputs listed in `argnums`.
 **/
SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [fun, argnums](auto inputs) {
    // Adapt `fun` to the vector-in / vector-out signature
    auto as_vector_fn = [fun](auto args) {
      return std::vector<array>{fun(args)};
    };
    auto [values, grads] = value_and_grad(as_vector_fn, argnums)(inputs);

    return std::make_pair(values[0], grads);
  };
}
|
||||||
|
|
||||||
|
/**
 * Single-argument convenience overload: differentiates with respect to the
 * input at index `argnum`.
 **/
SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> argnums{argnum};
  return value_and_grad(fun, argnums);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a function which computes the gradient of the input function with
|
||||||
|
* respect to a vector of input arrays.
|
||||||
|
*
|
||||||
|
* The function being differentiated takes a vector of arrays and returns an
|
||||||
|
* array. The vector of `argnums` specifies which the arguments to compute
|
||||||
|
* the gradient with respect to. At least one argument must be specified.
|
||||||
|
**/
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
||||||
|
const std::function<array(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
auto fn = value_and_grad(fun, argnums);
|
||||||
|
return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a function which computes the gradient of the input function with
|
||||||
|
* repsect to a single input array.
|
||||||
|
*
|
||||||
|
* The function being differentiated takes a vector of arrays and returns an
|
||||||
|
* array. The optional `argnum` index specifies which the argument to compute
|
||||||
|
* the gradient with respect to and defaults to 0.
|
||||||
|
**/
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
||||||
|
const std::function<array(const std::vector<array>&)>& fun,
|
||||||
|
int argnum = 0) {
|
||||||
|
return grad(fun, std::vector<int>{argnum});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a function which computes the gradient of the unary input function.
|
||||||
|
**/
|
||||||
|
std::function<array(const array&)> inline grad(
|
||||||
|
const std::function<array(const array&)>& fun) {
|
||||||
|
auto fn = value_and_grad(fun);
|
||||||
|
return [fn](const array& input) { return fn(input).second; };
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatically vectorize a unary function over the requested axes.
|
||||||
|
*/
|
||||||
|
std::function<array(const array&)> vmap(
|
||||||
|
const std::function<array(const array&)>& fun,
|
||||||
|
int in_axis = 0,
|
||||||
|
int out_axis = 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatically vectorize a binary function over the requested axes.
|
||||||
|
*/
|
||||||
|
std::function<array(const array&, const array&)> vmap(
|
||||||
|
const std::function<array(const array&, const array&)>& fun,
|
||||||
|
int in_axis_a = 0,
|
||||||
|
int in_axis_b = 0,
|
||||||
|
int out_axis = 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatically vectorize a function over the requested axes.
|
||||||
|
*
|
||||||
|
* The input function to `vmap` takes as an argument a vector of arrays and
|
||||||
|
* returns a vector of arrays. Optionally specify the axes to vectorize over
|
||||||
|
* with `in_axes` and `out_axes`, otherwise a default of 0 is used.
|
||||||
|
* Returns a vectorized function with the same signature as the input
|
||||||
|
* function.
|
||||||
|
*/
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> vmap(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<int>& in_axes = {},
|
||||||
|
const std::vector<int>& out_axes = {});
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
185
mlx/types/bf16.h
Normal file
185
mlx/types/bf16.h
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define __MLX_BFLOAT_NAN__ 0x7FC0
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
union float_bits_bf16 {
|
||||||
|
float f;
|
||||||
|
uint32_t u;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
struct _MLX_BFloat16 {
|
||||||
|
uint16_t bits_;
|
||||||
|
|
||||||
|
// Default constructor
|
||||||
|
_MLX_BFloat16() = default;
|
||||||
|
|
||||||
|
// Default copy constructor
|
||||||
|
_MLX_BFloat16(_MLX_BFloat16 const&) = default;
|
||||||
|
|
||||||
|
// Appease std::vector<bool> for being special
|
||||||
|
_MLX_BFloat16& operator=(std::vector<bool>::reference x) {
|
||||||
|
bits_ = x;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
_MLX_BFloat16& operator=(const float& x) {
|
||||||
|
return (*this = _MLX_BFloat16(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
// From float32
|
||||||
|
_MLX_BFloat16(const float& x) {
|
||||||
|
if (std::isnan(x)) {
|
||||||
|
bits_ = __MLX_BFLOAT_NAN__;
|
||||||
|
} else {
|
||||||
|
// Union
|
||||||
|
float_bits_bf16 in;
|
||||||
|
|
||||||
|
// Take bits
|
||||||
|
in.f = x;
|
||||||
|
|
||||||
|
// Round to nearest even
|
||||||
|
in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);
|
||||||
|
|
||||||
|
// Take upper 16 bits
|
||||||
|
bits_ = in.u >> 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To float32
|
||||||
|
operator float() const {
|
||||||
|
// Union
|
||||||
|
float_bits_bf16 out;
|
||||||
|
|
||||||
|
// Upper 16 bits are the data and lower 16 bits are 0s
|
||||||
|
out.u = ((uint32_t)bits_) << 16;
|
||||||
|
|
||||||
|
return out.f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Defines `otype __operator__(atype, btype)`: both operands are cast to
// `ctype`, combined with `__op__`, and the result is returned as `otype`.
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  inline otype __operator__(atype lhs, btype rhs) {                         \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);          \
  }

// Defines the two mixed overloads (_MLX_BFloat16, itype) and
// (itype, _MLX_BFloat16), computing in `ctype` and returning `otype`.
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
  inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) {            \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);     \
  }                                                                    \
  inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) {            \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);     \
  }

// Operators
// Mixing with float or double returns that type; mixing with bool or an
// integer type is computed in float and returns _MLX_BFloat16.
#define bfloat_binop(_op_, _operator_)                                       \
  bfloat_binop_base(                                                         \
      _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(_op_, _operator_, float, float, float);                \
  bfloat_binop_helper(_op_, _operator_, double, double, double);             \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float);         \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);     \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

#undef bfloat_binop
|
||||||
|
|
||||||
|
// Comparison ops
// Every overload returns bool. Comparisons against double are done in
// double; all others are done in float.
#define bfloat_compop(__op__, __operator__)                             \
  bfloat_binop_base(                                                    \
      __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
  bfloat_binop_helper(__op__, __operator__, bool, float, float);        \
  bfloat_binop_helper(__op__, __operator__, bool, double, double);      \
  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float);      \
  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float);      \
  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

#undef bfloat_compop
|
||||||
|
|
||||||
|
// Negative
|
||||||
|
inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
|
||||||
|
return -static_cast<float>(lhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inplace ops
// Defines compound assignment for (_MLX_BFloat16&, float) and
// (float&, _MLX_BFloat16) in terms of the corresponding binary operator.
#define bfloat_inplace_op(__op__, __operator__)                              \
  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
    lhs = lhs __op__ rhs;                                                    \
    return lhs;                                                              \
  }                                                                          \
  inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) {                \
    lhs = lhs __op__ rhs;                                                    \
    return lhs;                                                              \
  }

bfloat_inplace_op(+, operator+=);
bfloat_inplace_op(-, operator-=);
bfloat_inplace_op(*, operator*=);
bfloat_inplace_op(/, operator/=);

#undef bfloat_inplace_op
|
||||||
|
|
||||||
|
// Bitwise ops

// These act directly on the raw 16-bit pattern in `bits_`; no conversion to
// float takes place. Overloads cover (bf16, bf16), (bf16, uint16_t) and
// (uint16_t, bf16).
#define bfloat_bitop(__op__, __operator__)                                  \
  inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
    _MLX_BFloat16 out;                                                      \
    out.bits_ = lhs.bits_ __op__ rhs.bits_;                                 \
    return out;                                                             \
  }                                                                         \
  inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) {      \
    _MLX_BFloat16 out;                                                      \
    out.bits_ = lhs.bits_ __op__ rhs;                                       \
    return out;                                                             \
  }                                                                         \
  inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) {      \
    _MLX_BFloat16 out;                                                      \
    out.bits_ = lhs __op__ rhs.bits_;                                       \
    return out;                                                             \
  }

bfloat_bitop(|, operator|);
bfloat_bitop(&, operator&);
bfloat_bitop(^, operator^);

#undef bfloat_bitop
|
||||||
|
|
||||||
|
// In-place bitwise ops: update the raw 16-bit pattern of `lhs` with either
// another bfloat16's bits or a plain uint16_t, and return `lhs`.
#define bfloat_inplace_bitop(__op__, __operator__)                            \
  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                   \
    return lhs;                                                               \
  }                                                                           \
  inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) {      \
    lhs.bits_ = lhs.bits_ __op__ rhs;                                         \
    return lhs;                                                               \
  }

bfloat_inplace_bitop(|, operator|=);
bfloat_inplace_bitop(&, operator&=);
bfloat_inplace_bitop(^, operator^=);

#undef bfloat_inplace_bitop
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
3
python/mlx/nn/__init__.py
Normal file
3
python/mlx/nn/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from mlx.nn.layers import *
|
||||||
|
from mlx.nn import losses
|
||||||
|
from mlx.nn.utils import value_and_grad
|
23
python/mlx/nn/layers/__init__.py
Normal file
23
python/mlx/nn/layers/__init__.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from mlx.nn.layers.base import Module
|
||||||
|
from mlx.nn.layers.activations import (
|
||||||
|
GELU,
|
||||||
|
ReLU,
|
||||||
|
SiLU,
|
||||||
|
gelu,
|
||||||
|
gelu_approx,
|
||||||
|
gelu_fast_approx,
|
||||||
|
relu,
|
||||||
|
silu,
|
||||||
|
)
|
||||||
|
from mlx.nn.layers.containers import Sequential
|
||||||
|
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||||
|
from mlx.nn.layers.dropout import Dropout
|
||||||
|
from mlx.nn.layers.embedding import Embedding
|
||||||
|
from mlx.nn.layers.linear import Linear
|
||||||
|
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
|
||||||
|
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
|
||||||
|
from mlx.nn.layers.transformer import (
|
||||||
|
MultiHeadAttention,
|
||||||
|
TransformerEncoder,
|
||||||
|
TransformerEncoderLayer,
|
||||||
|
)
|
129
python/mlx/nn/layers/activations.py
Normal file
129
python/mlx/nn/layers/activations.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx.nn.layers.base import Module
|
||||||
|
|
||||||
|
|
||||||
|
def _make_activation_module(f):
|
||||||
|
def decorator(klass):
|
||||||
|
klass.__doc__ = f.__doc__
|
||||||
|
klass.__call__ = lambda self, x: f(x)
|
||||||
|
return klass
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def relu(x):
    """Applies the Rectified Linear Unit.

    Equivalent to ``mx.maximum(x, 0)``.
    """
    floor = 0
    return mx.maximum(x, floor)
|
||||||
|
|
||||||
|
|
||||||
|
def silu(x):
    r"""Applies the Sigmoid Linear Unit.

    Computes :math:`x \sigma(x)` element wise, with :math:`\sigma(\cdot)`
    denoting the logistic sigmoid.
    """
    gate = mx.sigmoid(x)
    return x * gate
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
    approximations.
    """
    # The docstring is a raw string so that ``\Phi`` is not parsed as an
    # (invalid) escape sequence; the resulting text is unchanged.
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_approx(x):
    r"""An approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    The approximation has a maximum absolute error :math:`< 0.0003` in the
    range :math:`[-6, 6]` and is computed as

    .. math::

        x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    correction = 1 + 0.0433603 * x.square()
    return x * mx.sigmoid(1.60033 * x * correction)
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_fast_approx(x):
    r"""A fast approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    The approximation has a maximum absolute error :math:`< 0.015` in the
    range :math:`[-6, 6]` and is computed as

    .. math::

        x = x \sigma\left(1.773 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.
    """
    scaled = 1.773 * x
    return x * mx.sigmoid(scaled)
|
||||||
|
|
||||||
|
|
||||||
|
# The decorator sets this class's docstring to ``relu``'s and installs a
# ``__call__`` that applies ``relu`` to its argument.
@_make_activation_module(relu)
class ReLU(Module):
    pass
|
||||||
|
|
||||||
|
|
||||||
|
# The decorator sets this class's docstring to ``silu``'s and installs a
# ``__call__`` that applies ``silu`` to its argument.
@_make_activation_module(silu)
class SiLU(Module):
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    When ``approx`` is set to 'precise' or 'fast' the layer instead applies

    .. math::
        \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)

    respectively.

    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
    functional equivalents and information regarding error bounds.

    Args:
        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
    """

    def __init__(self, approx="none"):
        super().__init__()

        # Pair every accepted option with the function implementing it
        choices = (
            ("none", gelu),
            ("precise", gelu_approx),
            ("fast", gelu_fast_approx),
        )
        for label, activation in choices:
            if approx == label:
                self._act = activation
                break
        else:
            raise ValueError(
                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
            )

    def __call__(self, x):
        return self._act(x)
|
6
python/mlx/nn/losses.py
Normal file
6
python/mlx/nn/losses.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
    """Computes the cross entropy between ``logits`` and integer ``targets``.

    Args:
        logits (mx.array): Unnormalized scores; the classes lie along ``axis``.
        targets (mx.array): Class indices, shaped like ``logits`` without
            the ``axis`` dimension.
        axis (int): The axis of ``logits`` that indexes the classes.

    Returns:
        mx.array: ``logsumexp(logits, axis) - logits[target]`` per element.
    """
    # Insert the index dimension at ``axis`` (not unconditionally at the end)
    # so the gather and the squeeze below line up for any choice of ``axis``.
    indices = mx.expand_dims(targets, axis)
    score = mx.take_along_axis(logits, indices, axis).squeeze(axis)
    return mx.logsumexp(logits, axis=axis) - score
|
136
python/mlx/utils.py
Normal file
136
python/mlx/utils.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
def tree_map(fn, tree, *rest):
    """Applies ``fn`` to the leaves of the python tree ``tree`` and
    returns a new collection with the results.

    If ``rest`` is provided, every item is assumed to be a superset of ``tree``
    and the corresponding leaves are passed to ``fn`` as extra positional
    arguments. In that respect, :meth:`tree_map` is closer to
    :func:`itertools.starmap` than to :func:`map`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (Callable): The function that processes the leaves of the tree
        tree (Any): The main python tree that will be iterated upon
        rest (Tuple[Any]): Extra trees to be iterated together with tree

    Returns:
        A python tree with the new values returned by ``fn``.
    """
    if isinstance(tree, dict):
        return {
            key: tree_map(fn, subtree, *(other[key] for other in rest))
            for key, subtree in tree.items()
        }

    if isinstance(tree, (list, tuple)):
        mapped = (
            tree_map(fn, subtree, *(other[idx] for other in rest))
            for idx, subtree in enumerate(tree)
        )
        # Lists map to lists, tuples map to tuples
        return list(mapped) if isinstance(tree, list) else tuple(mapped)

    # Anything else is a leaf
    return fn(tree, *rest)
|
||||||
|
|
||||||
|
|
||||||
|
def tree_flatten(tree, prefix="", is_leaf=None):
    """Flattens a python tree to a list of key, value tuples.

    Keys use dot notation to describe trees of arbitrary depth and
    complexity.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], ".hello"))
        # [("hello.0.0.0", 0)]

    .. note::
       Dictionaries should have keys that are valid python identifiers.

    Args:
        tree (Any): The python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
        is_leaf (Callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.

    Returns:
        List[Tuple[str, Any]]: The flat representation of the python tree.
    """
    forced_leaf = is_leaf is not None and is_leaf(tree)

    if not forced_leaf:
        # Pick the (key, child) pairs of a container, if ``tree`` is one
        if isinstance(tree, (list, tuple)):
            entries = enumerate(tree)
        elif isinstance(tree, dict):
            entries = tree.items()
        else:
            entries = None

        if entries is not None:
            flattened = []
            for key, child in entries:
                flattened += tree_flatten(child, f"{prefix}.{key}", is_leaf)
            return flattened

    # Leaf: drop the leading separator from the accumulated prefix
    return [(prefix[1:], tree)]
|
||||||
|
|
||||||
|
|
||||||
|
def tree_unflatten(tree):
    """Recreate a python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

    Args:
        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
        A python tree. An empty flat representation yields an empty dict.
    """
    # Nothing to rebuild; without this guard ``tree[0]`` below would raise
    # IndexError on the flat form of an empty container.
    if len(tree) == 0:
        return {}

    # A single entry with an empty key is a bare leaf
    if len(tree) == 1 and tree[0][0] == "":
        return tree[0][1]

    # The container type is decided by the first key: integer -> list
    try:
        int(tree[0][0].split(".", maxsplit=1)[0])
        is_list = True
    except ValueError:
        is_list = False

    # collect children
    children = {}
    for key, value in tree:
        current_idx, *next_idx = key.split(".", maxsplit=1)
        next_idx = "" if not next_idx else next_idx[0]
        children.setdefault(current_idx, []).append((next_idx, value))

    # recursively map them to the original container
    if is_list:
        keys = sorted((int(idx), idx) for idx in children.keys())
        items = []
        for i, k in keys:
            # Missing indices are filled with empty dicts
            while i > len(items):
                items.append({})
            items.append(tree_unflatten(children[k]))
        return items
    else:
        return {k: tree_unflatten(v) for k, v in children.items()}
|
1071
python/src/array.cpp
Normal file
1071
python/src/array.cpp
Normal file
File diff suppressed because it is too large
Load Diff
42
python/src/device.cpp
Normal file
42
python/src/device.cpp
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#include "mlx/device.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace py::literals;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void init_device(py::module_& m) {
  // Python enum for the device kinds. export_values() is called so the
  // enumerators are also exported into the enclosing module scope.
  py::enum_<Device::DeviceType> device_type(m, "DeviceType");
  device_type.value("cpu", Device::DeviceType::cpu);
  device_type.value("gpu", Device::DeviceType::gpu);
  device_type.export_values();
  // Comparison between an enum value and a full Device; registered with
  // py::prepend() so it is tried before the overloads already present.
  device_type.def(
      "__eq__",
      [](const Device::DeviceType& kind, const Device& device) {
        return kind == device;
      },
      py::prepend());

  // __repr__ reuses the C++ stream output operator of Device.
  auto device_repr = [](const Device& device) {
    std::ostringstream buffer;
    buffer << device;
    return buffer.str();
  };
  auto device_equal = [](const Device& lhs, const Device& rhs) {
    return lhs == rhs;
  };

  py::class_<Device> device_class(m, "Device");
  device_class.def(
      py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0);
  device_class.def_readonly("type", &Device::type);
  device_class.def("__repr__", device_repr);
  device_class.def("__eq__", device_equal);

  // Let Python callers pass a DeviceType wherever a Device is expected.
  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_device", &default_device);
  m.def("set_default_device", &set_default_device, "device"_a);
}
|
12
python/src/metal.cpp
Normal file
12
python/src/metal.cpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void init_metal(py::module_& m) {
  // Create the "metal" submodule (docstring "mlx.metal") and attach the
  // availability query to it in one chained expression.
  m.def_submodule("metal", "mlx.metal")
      .def("is_available", &metal::is_available);
}
|
32
python/src/stream.cpp
Normal file
32
python/src/stream.cpp
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace py::literals;
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void init_stream(py::module_& m) {
  // __repr__ reuses the C++ stream output operator of Stream.
  auto stream_repr = [](const Stream& stream) {
    std::ostringstream buffer;
    buffer << stream;
    return buffer.str();
  };
  auto stream_equal = [](const Stream& lhs, const Stream& rhs) {
    return lhs == rhs;
  };

  py::class_<Stream> stream_class(m, "Stream");
  stream_class.def(py::init<int, Device>(), "index"_a, "device"_a);
  stream_class.def_readonly("device", &Stream::device);
  stream_class.def("__repr__", stream_repr);
  stream_class.def("__eq__", stream_equal);

  // Let Python callers pass a DeviceType wherever a Device is expected
  // (the stream functions below take a device argument).
  py::implicitly_convertible<Device::DeviceType, Device>();

  m.def("default_stream", &default_stream, "device"_a);
  m.def("set_default_stream", &set_default_stream, "stream"_a);
  m.def("new_stream", &new_stream, "device"_a);
}
|
188
python/tests/test_bf16.py
Normal file
188
python/tests/test_bf16.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
import unittest
|
||||||
|
from itertools import permutations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mlx_tests
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
has_torch = True
|
||||||
|
except ImportError as e:
|
||||||
|
has_torch = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestBF16(mlx_tests.MLXTestCase):
    """Compare bfloat16 results from mlx against numpy and, if installed, torch."""

    def __test_ops(
        self,
        ref_op,  # Function that outputs array_like
        mlx_op,  # Function that outputs array_like
        np_args,  # Numpy arguments
        ref_transform=lambda x: x,
        mlx_transform=lambda x: mx.array(x),
        atol=1e-5,
    ):
        """Run ``ref_op`` and ``mlx_op`` on transformed copies of ``np_args``
        and assert that the two results agree within ``atol``.

        Args:
            ref_op: Reference implementation; called with the arguments
                produced by ``ref_transform``.
            mlx_op: mlx implementation; called with the arguments produced
                by ``mlx_transform``.
            np_args: Iterable of numpy inputs shared by both sides.
            ref_transform: Maps each numpy input to a reference argument.
            mlx_transform: Maps each numpy input to an mlx argument.
            atol: Absolute tolerance passed to ``np.allclose``.
        """
        ref_args = map(ref_transform, np_args)
        mlx_args = map(mlx_transform, np_args)

        r_ref = ref_op(*ref_args)
        r_mlx = mlx_op(*mlx_args)

        self.assertTrue(np.allclose(r_mlx, r_ref, atol=atol))

    def __default_test(
        self,
        op,
        np_args,
        simple_transform=lambda x: x,
        atol_np=1e-3,
        atol_torch=1e-5,
        np_kwargs=None,
        mlx_kwargs=None,
        torch_kwargs=None,
        torch_op=None,
    ):
        """Check the mlx op named ``op`` in bfloat16 against numpy and torch.

        Args:
            op: Name of the function, looked up on ``mx`` and ``np`` (and on
                ``torch`` unless ``torch_op`` is given).
            np_args: Tuple of float32 numpy inputs.
            simple_transform: Extra transform (e.g. a slice) applied to every
                input after it has been converted for the respective library.
            atol_np: Absolute tolerance for the numpy comparison.
            atol_torch: Absolute tolerance for the torch comparison.
            np_kwargs: Keyword arguments for the numpy call (default: none).
            mlx_kwargs: Keyword arguments for the mlx call (default: none).
            torch_kwargs: Keyword arguments for the torch call (default: none).
            torch_op: Name of the torch function when it differs from ``op``.
        """
        # ``None`` defaults avoid sharing one mutable dict between calls.
        np_kwargs = {} if np_kwargs is None else np_kwargs
        mlx_kwargs = {} if mlx_kwargs is None else mlx_kwargs
        torch_kwargs = {} if torch_kwargs is None else torch_kwargs

        with self.subTest(reference="numpy"):

            def np_transform(x):
                # Round-trip through bfloat16 so the numpy reference sees the
                # same rounded values as mlx, stored as float32.
                x_mx_bf16 = mx.array(x).astype(mx.bfloat16)
                x_mx_fp32 = x_mx_bf16.astype(mx.float32)
                return np.asarray(x_mx_fp32)

            def mlx_fn(*args):
                out_bf16 = getattr(mx, op)(*args, **mlx_kwargs)
                return np.asarray(out_bf16.astype(mx.float32))

            def np_fn(*args):
                out_fp32 = getattr(np, op)(*args, **np_kwargs)
                # Round the float32 reference result to bfloat16 precision.
                return np_transform(out_fp32)

            ref_op = np_fn
            mlx_op = mlx_fn

            ref_transform = lambda x: simple_transform(np_transform(x))
            mlx_transform = lambda x: simple_transform(mx.array(x).astype(mx.bfloat16))

            self.__test_ops(
                ref_op,
                mlx_op,
                np_args,
                ref_transform=ref_transform,
                mlx_transform=mlx_transform,
                atol=atol_np,
            )

        if has_torch:
            with self.subTest(reference="torch"):
                torch_op = op if torch_op is None else torch_op

                def torch_fn(*args):
                    out_bf16 = getattr(torch, torch_op)(*args, **torch_kwargs)
                    return out_bf16.to(torch.float32).numpy()

                ref_op = torch_fn
                ref_transform = lambda x: simple_transform(
                    torch.from_numpy(x).to(torch.bfloat16)
                )
                # mlx_op and mlx_transform are reused from the numpy section.
                self.__test_ops(
                    ref_op,
                    mlx_op,
                    np_args,
                    ref_transform=ref_transform,
                    mlx_transform=mlx_transform,
                    atol=atol_torch,
                )

    def test_unary_ops(self):
        """Element-wise unary ops on a random 3-D input."""
        x = np.random.rand(18, 28, 38)
        for op in ["abs", "exp", "log", "square", "sqrt"]:
            with self.subTest(op=op):
                np_args = (x.astype(np.float32),)
                self.__default_test(op, np_args)

    def test_binary_ops(self):
        """Element-wise binary ops on full inputs and on sliced inputs."""
        x = np.random.rand(18, 28, 38)
        y = np.random.rand(18, 28, 38)
        for op in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
            with self.subTest(op=op):
                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                self.__default_test(op, np_args, simple_transform=lambda x: x)
                self.__default_test(op, np_args, simple_transform=lambda x: x[:1])
                self.__default_test(op, np_args, simple_transform=lambda x: x[:, :1])

    def test_reduction_ops(self):
        """min/max reductions over every combination of axes."""
        x = np.random.rand(18, 28, 38).astype(np.float32)

        for op in ("min", "max"):
            with self.subTest(op=op):

                for axes in (0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)):
                    with self.subTest(axes=axes):
                        np_args = (x.astype(np.float32),)
                        self.__default_test(
                            op,
                            np_args,
                            np_kwargs={"axis": axes},
                            mlx_kwargs={"axis": axes},
                            torch_kwargs={"dim": axes},
                            # The torch counterparts are "amin" / "amax".
                            torch_op="a" + op,
                        )

    def test_arg_reduction_ops(self):
        """argmin/argmax on bfloat16 data against numpy on the same values."""
        data = np.random.rand(10, 12, 13).astype(np.float32)
        x = mx.array(data).astype(mx.bfloat16)
        # Re-read the rounded values so numpy sees exactly what mlx sees.
        data = np.asarray(x.astype(mx.float32))

        for op in ["argmin", "argmax"]:
            for axis in range(3):
                for kd in [True, False]:
                    a = getattr(mx, op)(x, axis, kd)
                    b = getattr(np, op)(data, axis, keepdims=kd)
                    a = a.astype(mx.float32)
                    self.assertEqual(a.tolist(), b.tolist())

        for op in ["argmin", "argmax"]:
            a = getattr(mx, op)(x, keepdims=True)
            b = getattr(np, op)(data, keepdims=True)
            a = a.astype(mx.float32)
            self.assertEqual(a.tolist(), b.tolist())
            a = getattr(mx, op)(x)
            b = getattr(np, op)(data)
            a = a.astype(mx.float32)
            self.assertEqual(a.item(), b)

    def test_blas_ops(self):
        """bfloat16 matmul for several shapes; only run on the GPU device."""
        if mx.default_device() != mx.gpu:
            # Report the test as skipped instead of silently passing.
            self.skipTest("bfloat16 matmul test requires the GPU default device")

        def test_blas(shape_x, shape_y):
            np.random.seed(42)
            with self.subTest(shape_x=shape_x, shape_y=shape_y):
                x = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_x)
                y = np.random.normal(0.0, 1.0 / shape_x[-1], size=shape_y)

                np_args = (
                    x.astype(np.float32),
                    y.astype(np.float32),
                )
                op = "matmul"

                self.__default_test(op, np_args, atol_np=1e-3, atol_torch=1e-3)

        for shape_x, shape_y in [
            [(32, 32), (32, 32)],
            [(23, 57), (57, 1)],
            [(1, 3), (3, 128)],
            [(8, 128, 768), (768, 16)],
        ]:
            test_blas(shape_x, shape_y)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
224
tests/creations_tests.cpp
Normal file
224
tests/creations_tests.cpp
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test arange") {
  // The dtype follows the argument types unless one is passed explicitly.
  {
    auto seq = arange(10);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(10.0);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(10, float32);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(10.0, int32);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(0, 10);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(0.0, 10.0, int32);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(0.0, 10.0);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(0, 10, float32);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(0, 10, 0.1, float32);
    CHECK_EQ(seq.dtype(), float32);

    seq = arange(0.0, 10.0, 0.5, int32);
    CHECK_EQ(seq.dtype(), int32);

    seq = arange(10.0, uint32);
    CHECK_EQ(seq.dtype(), uint32);
    seq = arange(0.0, 10.0, uint32);
    CHECK_EQ(seq.dtype(), uint32);
    seq = arange(0.0, 10.0, 0.5, uint32);
    CHECK_EQ(seq.dtype(), uint32);

    // bool_ is rejected as an arange dtype.
    CHECK_THROWS_AS(arange(10, bool_), std::invalid_argument);
  }

  // Number of elements for various start / stop / step combinations.
  {
    auto seq = arange(10);
    CHECK_EQ(seq.size(), 10);

    seq = arange(0.0, 10.0, 0.5);
    CHECK_EQ(seq.size(), 20);

    seq = arange(0.0, 10.0, 0.45);
    CHECK_EQ(seq.size(), 23);

    seq = arange(0, 10, 10);
    CHECK_EQ(seq.size(), 1);

    seq = arange(0, 10, 9);
    CHECK_EQ(seq.size(), 2);

    seq = arange(0, 10, 100);
    CHECK_EQ(seq.size(), 1);

    seq = arange(0, -10, 1);
    CHECK_EQ(seq.size(), 0);

    seq = arange(0, -10, -1);
    CHECK_EQ(seq.size(), 10);

    seq = arange(0, -10, -10);
    CHECK_EQ(seq.size(), 1);
  }

  // Element values, including empty ranges and negative steps.
  {
    auto seq = arange(0, 3);
    CHECK(array_equal(seq, array({0, 1, 2})).item<bool>());

    seq = arange(0, 3, 2);
    CHECK(array_equal(seq, array({0, 2})).item<bool>());

    seq = arange(0, 3, 3);
    CHECK(array_equal(seq, array({0})).item<bool>());

    seq = arange(0, -3, 1);
    CHECK(array_equal(seq, array({})).item<bool>());

    seq = arange(0, 3, -1);
    CHECK(array_equal(seq, array({})).item<bool>());

    seq = arange(0, -3, -1);
    CHECK(array_equal(seq, array({0, -1, -2})).item<bool>());

    // Fractional steps with an integer dtype: the expected values are ten
    // zeros for step 0.5 and {0, 1, 2, 3} for step 1.5.
    seq = arange(0.0, 5.0, 0.5, int32);
    CHECK(array_equal(seq, zeros({10})).item<bool>());

    seq = arange(0.0, 5.0, 1.5, int32);
    CHECK(array_equal(seq, array({0, 1, 2, 3})).item<bool>());
  }
}
|
||||||
|
|
||||||
|
TEST_CASE("test astype") {
  // Conversions between integer and floating point scalars.
  {
    auto source = array(1);
    auto converted = astype(source, float32);
    CHECK_EQ(converted.dtype(), float32);
    CHECK_EQ(converted.item<float>(), 1.0f);

    converted = astype(source, int32);
    CHECK_EQ(converted.dtype(), int32);
    CHECK_EQ(converted.item<int>(), 1);

    source = array(-3.0f);
    converted = astype(source, int32);
    CHECK_EQ(converted.dtype(), int32);
    CHECK_EQ(converted.item<int>(), -3);

    converted = astype(source, uint32);
    CHECK_EQ(converted.dtype(), uint32);

    // A negative float converted to unsigned is platform dependent, so the
    // expected value is produced with std::copy on the same data.
    uint32_t expected_value;
    std::copy(
        source.data<float>(), source.data<float>() + 1, &expected_value);
    CHECK_EQ(converted.item<uint32_t>(), expected_value);
  }
}
|
||||||
|
|
||||||
|
TEST_CASE("test full") {
  // Scalar and small arrays with inferred and explicit dtypes.
  {
    auto filled = full({}, 0);
    CHECK_EQ(filled.dtype(), int32);
    CHECK_EQ(filled.item<int>(), 0);

    filled = full({}, 0.0);
    CHECK_EQ(filled.dtype(), float32);
    CHECK_EQ(filled.item<float>(), 0);

    filled = full({}, false);
    CHECK_EQ(filled.item<bool>(), false);

    filled = full({}, 0, int32);
    CHECK_EQ(filled.item<int>(), 0);

    filled = full({}, 0, float32);
    CHECK_EQ(filled.item<float>(), 0);

    filled = full({1, 2}, 2, float32);
    CHECK(array_equal(filled, array({2.0, 2.0}, {1, 2})).item<bool>());

    filled = full({2, 1}, 2, float32);
    CHECK(array_equal(filled, array({2.0, 2.0}, {2, 1})).item<bool>());

    filled = full({2}, false);
    CHECK_EQ(filled.dtype(), bool_);
    CHECK(array_equal(filled, array({false, false})).item<bool>());

    filled = full({2}, 1.0, bool_);
    CHECK_EQ(filled.dtype(), bool_);
    CHECK(array_equal(filled, array({true, true})).item<bool>());

    filled = full({2}, 1.0, uint32);
    CHECK_EQ(filled.dtype(), uint32);
    CHECK(array_equal(filled, array({1, 1})).item<bool>());

    // An empty fill value cannot populate a non-empty shape.
    CHECK_THROWS_AS(full({2}, array({})), std::invalid_argument);
  }

  // The fill value is broadcast to the requested shape.
  {
    auto filled = full({2, 2}, array({3, 4}, {2, 1}));
    CHECK(array_equal(filled, array({3, 3, 4, 4}, {2, 2})).item<bool>());
    filled = full({2, 2}, array({3, 4}, {1, 2}));
    CHECK(array_equal(filled, array({3, 4, 3, 4}, {2, 2})).item<bool>());
  }

  // zeros / ones and their *_like variants.
  {
    auto actual = zeros({2, 2}, float32);
    CHECK_EQ(actual.shape(), std::vector<int>{2, 2});
    CHECK_EQ(actual.ndim(), 2);
    CHECK_EQ(actual.dtype(), float32);
    auto reference = array({0.0, 0.0, 0.0, 0.0}, {2, 2});
    CHECK(array_equal(actual, reference).item<bool>());

    actual = ones({2, 2}, float32);
    CHECK_EQ(actual.shape(), std::vector<int>{2, 2});
    CHECK_EQ(actual.ndim(), 2);
    CHECK_EQ(actual.dtype(), float32);
    reference = array({1.0, 1.0, 1.0, 1.0}, {2, 2});
    CHECK(array_equal(actual, reference).item<bool>());

    actual = zeros({2, 2}, int32);
    reference = zeros_like(actual);
    CHECK_EQ(reference.dtype(), int32);
    CHECK(array_equal(actual, reference).item<bool>());

    actual = ones({2, 2}, int32);
    reference = ones_like(actual);
    CHECK_EQ(reference.dtype(), int32);
    CHECK(array_equal(actual, reference).item<bool>());
  }

  // Empty shape (scalar) and empty array inputs.
  {
    array filled = ones({}, int32);
    CHECK_EQ(filled.shape(), std::vector<int>{});
    CHECK_EQ(filled.item<int>(), 1);

    filled = full({0}, array({}));
    CHECK_EQ(filled.shape(), std::vector<int>{0});
    CHECK_EQ(filled.size(), 0);

    CHECK_THROWS_AS(full({}, array({})), std::invalid_argument);
  }
}
|
331
tests/fft_tests.cpp
Normal file
331
tests/fft_tests.cpp
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test fft basics") {
  // Run on the CPU device; the previous default is restored at the end.
  auto saved_device = default_device();
  set_default_device(Device::cpu);

  // A scalar (0-d) input is rejected by both transforms.
  array signal(1.0);
  CHECK_THROWS(fft::fft(signal));
  CHECK_THROWS(fft::ifft(signal));

  // Single real element.
  signal = array({1.0});
  auto spectrum = fft::fft(signal);
  CHECK_EQ(spectrum.dtype(), complex64);
  CHECK_EQ(spectrum.size(), signal.size());
  CHECK_EQ(spectrum.item<complex64_t>(), complex64_t{1.0f, 0.0f});

  spectrum = fft::ifft(signal);
  CHECK_EQ(spectrum.dtype(), complex64);
  CHECK_EQ(spectrum.size(), signal.size());
  CHECK_EQ(spectrum.item<complex64_t>(), complex64_t{1.0f, 0.0f});

  // Single complex element.
  signal = array({complex64_t{1.0f, 1.0f}}, complex64);
  spectrum = fft::fft(signal);
  CHECK_EQ(spectrum.size(), signal.size());
  CHECK_EQ(spectrum.item<complex64_t>(), complex64_t{1.0f, 1.0f});

  spectrum = fft::ifft(signal);
  CHECK_EQ(spectrum.dtype(), complex64);
  CHECK_EQ(spectrum.size(), signal.size());
  CHECK_EQ(spectrum.item<complex64_t>(), complex64_t{1.0f, 1.0f});

  // Forward and inverse transform of a short real sequence.
  {
    signal = array({0.0f, 1.0f, 2.0f, 3.0f});
    spectrum = fft::fft(signal);
    std::initializer_list<complex64_t> forward_expected = {
        {6.0, 0.0},
        {-2.0, 2.0},
        {-2.0, 0.0},
        {-2.0, -2.0},
    };
    CHECK_EQ(spectrum.size(), signal.size());
    CHECK(array_equal(spectrum, array(forward_expected)).item<bool>());

    spectrum = fft::ifft(signal);
    std::initializer_list<complex64_t> inverse_expected = {
        {1.5, 0.0},
        {-0.5, -0.5},
        {-0.5, 0.0},
        {-0.5, 0.5},
    };
    CHECK(array_equal(spectrum, array(inverse_expected)).item<bool>());
  }

  // Complex input, plus a round trip through the inverse transform.
  {
    std::initializer_list<complex64_t> samples = {
        {1.0f, 1.0f}, {2.0f, 1.0f}, {1.0f, 2.0f}, {2.0f, 2.0f}};
    signal = array(samples);
    spectrum = fft::fft(signal);
    std::initializer_list<complex64_t> forward_expected = {
        {6.0, 6.0},
        {-1.0, -1.0},
        {-2.0, 0.0},
        {1.0, -1.0},
    };
    CHECK_EQ(spectrum.size(), signal.size());
    CHECK(array_equal(spectrum, array(forward_expected)).item<bool>());
    CHECK(array_equal(fft::ifft(spectrum), signal).item<bool>());
  }

  // Transform along an explicitly chosen axis of a 2x2 input.
  {
    signal = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
    std::initializer_list<complex64_t> axis0_expected = {
        {2.0, 0.0},
        {4.0, 0.0},
        {-2.0, 0.0},
        {-2.0, 0.0},
    };
    spectrum = fft::fft(signal, 0);
    CHECK(array_equal(spectrum, array(axis0_expected, {2, 2})).item<bool>());
    CHECK(array_equal(fft::ifft(spectrum, 0), signal).item<bool>());
    std::initializer_list<complex64_t> axis1_expected = {
        {1.0, 0.0},
        {-1.0, 0.0},
        {5.0, 0.0},
        {-1.0, 0.0},
    };
    spectrum = fft::fft(signal, 1);
    CHECK(array_equal(spectrum, array(axis1_expected, {2, 2})).item<bool>());
    CHECK(array_equal(fft::ifft(spectrum, 1), signal).item<bool>());
  }
  set_default_device(saved_device);
}
|
||||||
|
|
||||||
|
TEST_CASE("test real ffts") {
  // Run on the CPU device; the previous default is restored at the end.
  auto saved_device = default_device();
  set_default_device(Device::cpu);

  // Single element real transform.
  auto signal = array({1.0});
  auto spectrum = fft::rfft(signal);
  CHECK_EQ(spectrum.dtype(), complex64);
  CHECK_EQ(spectrum.size(), signal.size());
  CHECK_EQ(spectrum.item<complex64_t>(), complex64_t{1.0f, 0.0f});

  // The real transform of n inputs is expected to hold n / 2 + 1 outputs.
  {
    signal = array({0.0f, 1.0f, 2.0f, 3.0f});
    spectrum = fft::rfft(signal);
    std::initializer_list<complex64_t> half_spectrum = {
        {6.0, 0.0}, {-2.0, 2.0}, {-2.0, -0.0}};
    CHECK_EQ(spectrum.size(), signal.size() / 2 + 1);
    CHECK(array_equal(spectrum, array(half_spectrum)).item<bool>());
  }

  // A scalar (0-d) complex input is rejected by the inverse real transform.
  signal = array(complex64_t{1, 1});
  CHECK_THROWS(fft::irfft(signal));

  // Inverse real transform of two complex values gives two real outputs.
  signal = array({complex64_t{0, 1}, complex64_t{1, 0}});
  spectrum = fft::irfft(signal);
  CHECK_EQ(spectrum.size(), 2);
  CHECK_EQ(spectrum.dtype(), float32);
  CHECK(array_equal(spectrum, array({0.5f, -0.5f})).item<bool>());

  set_default_device(saved_device);
}
|
||||||
|
|
||||||
|
TEST_CASE("test fftn") {
  // Run on the CPU device; the previous default is restored at the end.
  auto saved_device = default_device();
  set_default_device(Device::cpu);

  // Invalid axes / shape arguments must raise std::invalid_argument.
  auto input = zeros({5, 5, 5});
  CHECK_THROWS_AS(fft::fftn(input, {}, {0, 3}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(input, {}, {0, -4}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(input, {}, {0, 0}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(input, {5, 5, 5}, {0}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(input, {0}, {}, {}), std::invalid_argument);
  CHECK_THROWS_AS(fft::fftn(input, {1, -1}, {}, {}), std::invalid_argument);

  // 2D transform and its inverse on a 2x2 input.
  {
    input = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
    std::initializer_list<complex64_t> reference = {
        {6.0, 0.0},
        {-2.0, 0.0},
        {-4.0, 0.0},
        {0.0, 0.0},
    };
    auto output = fft::fft2(input);
    CHECK(array_equal(output, array(reference, {2, 2})).item<bool>());
    CHECK(array_equal(fft::ifft2(output), input).item<bool>());
  }

  // 3D transform, followed by output shapes of the real n-d transforms.
  {
    input = reshape(arange(8, float32), {2, 2, 2});
    std::initializer_list<complex64_t> reference = {
        {28.0, 0.0},
        {-4.0, 0.0},
        {-8.0, 0.0},
        {0.0, 0.0},
        {-16.0, 0.0},
        {0.0, 0.0},
        {0.0, 0.0},
        {0.0, 0.0},
    };
    auto output = fft::fftn(input);
    CHECK(array_equal(output, array(reference, {2, 2, 2})).item<bool>());
    CHECK(array_equal(fft::ifftn(output), input).item<bool>());

    input = reshape(arange(20, float32), {5, 4});
    output = fft::rfftn(input);
    CHECK_EQ(output.shape(), std::vector<int>{5, 3});
    output = fft::rfftn(input, {1, 0});
    CHECK_EQ(output.shape(), std::vector<int>{3, 4});

    input = reshape(arange(20, float32), {5, 4});
    output = fft::irfftn(input);
    CHECK_EQ(output.shape(), std::vector<int>{5, 6});
    output = fft::irfftn(input, {1, 0});
    CHECK_EQ(output.shape(), std::vector<int>{8, 4});
  }

  // Output shape and dtype of the real transforms.
  {
    input = zeros({5, 5}, float32);
    auto output = fft::rfft2(input);
    CHECK_EQ(output.shape(), std::vector<int>{5, 3});
    CHECK_EQ(output.dtype(), complex64);

    output = fft::rfftn(input);
    CHECK_EQ(output.shape(), std::vector<int>{5, 3});
    CHECK_EQ(output.dtype(), complex64);

    input = zeros({5, 5}, complex64);
    output = fft::irfft2(input);
    CHECK_EQ(output.shape(), std::vector<int>{5, 8});
    CHECK_EQ(output.dtype(), float32);

    output = fft::irfftn(input);
    CHECK_EQ(output.shape(), std::vector<int>{5, 8});
    CHECK_EQ(output.dtype(), float32);
  }

  set_default_device(saved_device);
}
|
||||||
|
|
||||||
|
TEST_CASE("test fft with provided shape") {
  // Only output shapes are inspected here: the requested length n replaces
  // the size of the transformed axis (n / 2 + 1 for the real transform).
  auto input = ones({5, 5});

  auto output = fft::fft(input, 7, 0);
  CHECK_EQ(output.shape(), std::vector<int>{7, 5});

  output = fft::fft(input, 3, 0);
  CHECK_EQ(output.shape(), std::vector<int>{3, 5});

  output = fft::fft(input, 7, 1);
  CHECK_EQ(output.shape(), std::vector<int>{5, 7});

  output = fft::fft(input, 3, 1);
  CHECK_EQ(output.shape(), std::vector<int>{5, 3});

  output = fft::rfft(input, 7, 0);
  CHECK_EQ(output.shape(), std::vector<int>{4, 5});

  output = fft::rfft(input, 3, 0);
  CHECK_EQ(output.shape(), std::vector<int>{2, 5});

  output = fft::rfft(input, 3, 1);
  CHECK_EQ(output.shape(), std::vector<int>{5, 2});
}
|
||||||
|
|
||||||
|
TEST_CASE("test fft vmap") {
  // Run on the CPU device; the previous default is restored at the end.
  auto saved_device = default_device();
  set_default_device(Device::cpu);

  auto complex_transform = [](array in) { return fft::fft(in); };
  auto batch = reshape(arange(8), {2, 4});

  // Mapping over the leading axis must match the plain call on the batch.
  auto mapped = vmap(complex_transform)(batch);
  CHECK(array_equal(mapped, fft::fft(batch)).item<bool>());

  // Mapping over axis 1 must match a transform along axis 0.
  mapped = vmap(complex_transform, 1, 1)(batch);
  CHECK(array_equal(mapped, fft::fft(batch, 0)).item<bool>());

  auto real_transform = [](array in) { return fft::rfft(in); };

  mapped = vmap(real_transform)(batch);
  CHECK(array_equal(mapped, fft::rfft(batch)).item<bool>());

  mapped = vmap(real_transform, 1, 1)(batch);
  CHECK(array_equal(mapped, fft::rfft(batch, 0)).item<bool>());

  set_default_device(saved_device);
}
|
||||||
|
|
||||||
|
TEST_CASE("test fft grads") {
  // Run on the CPU device; the previous default is restored at the end.
  auto saved_device = default_device();
  set_default_device(Device::cpu);

  // Forward transform: vjp and jvp are compared with fft of the (co)tangent.
  auto forward_fn = [](array in) { return fft::fft(in); };
  auto cotangent = astype(arange(10), complex64);
  auto vjp_result = vjp(forward_fn, zeros_like(cotangent), cotangent).second;
  CHECK(array_equal(fft::fft(cotangent), vjp_result).item<bool>());

  auto tangent = astype(arange(10), complex64);
  auto jvp_result = jvp(forward_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::fft(tangent), jvp_result).item<bool>());

  // Inverse transform.
  auto inverse_fn = [](array in) { return fft::ifft(in); };
  vjp_result = vjp(inverse_fn, zeros_like(cotangent), cotangent).second;
  CHECK(array_equal(fft::ifft(cotangent), vjp_result).item<bool>());

  jvp_result = jvp(inverse_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::ifft(tangent), jvp_result).item<bool>());

  // Real transform: the vjp is compared with the float32 cast of a
  // length-10 fft of the cotangent.
  auto real_fn = [](array in) { return fft::rfft(in); };
  cotangent = astype(arange(6), complex64);
  vjp_result = vjp(real_fn, zeros({10}), cotangent).second;
  auto reference = astype(fft::fft(cotangent, 10, 0), float32);
  CHECK(array_equal(reference, vjp_result).item<bool>());

  tangent = astype(arange(10), float32);
  jvp_result = jvp(real_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::rfft(tangent), jvp_result).item<bool>());

  // Inverse real transform: the first and last compared entries must match
  // the reference, the entries in between are expected to be doubled.
  auto inverse_real_fn = [](array in) { return fft::irfft(in); };
  cotangent = astype(arange(10), float32);
  vjp_result =
      vjp(inverse_real_fn, astype(zeros({6}), complex64), cotangent).second;
  reference = fft::fft(cotangent, 10, 0);
  auto result_parts = split(vjp_result, {1, 5});
  auto reference_parts = split(reference, {1, 5, 6});
  CHECK_EQ(
      reference_parts[0].item<complex64_t>(),
      result_parts[0].item<complex64_t>());
  CHECK(array_equal(2 * reference_parts[1], result_parts[1]).item<bool>());
  CHECK_EQ(
      reference_parts[2].item<complex64_t>(),
      result_parts[2].item<complex64_t>());

  tangent = astype(arange(10), complex64);
  jvp_result = jvp(inverse_real_fn, zeros_like(tangent), tangent).second;
  CHECK(array_equal(fft::irfft(tangent), jvp_result).item<bool>());

  // N-d variants: only the shape of the vjp result is checked.
  vjp_result = vjp([](array in) { return fft::fftn(in); },
                   astype(zeros({5, 5}), complex64),
                   astype(zeros({5, 5}), complex64))
                   .second;
  CHECK_EQ(vjp_result.shape(), std::vector<int>{5, 5});

  vjp_result = vjp([](array in) { return fft::ifftn(in); },
                   astype(zeros({5, 5}), complex64),
                   astype(zeros({5, 5}), complex64))
                   .second;
  CHECK_EQ(vjp_result.shape(), std::vector<int>{5, 5});

  vjp_result = vjp([](array in) { return fft::rfftn(in); },
                   zeros({5, 9}),
                   astype(zeros({5, 5}), complex64))
                   .second;
  CHECK_EQ(vjp_result.shape(), std::vector<int>{5, 9});

  vjp_result = vjp([](array in) { return fft::irfftn(in); },
                   astype(zeros({5, 5}), complex64),
                   zeros({5, 8}))
                   .second;
  CHECK_EQ(vjp_result.shape(), std::vector<int>{5, 5});

  set_default_device(saved_device);
}
|
30
tests/graph_optimize_tests.cpp
Normal file
30
tests/graph_optimize_tests.cpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
TEST_CASE("test simplify scalars") {
  // Two maximum() calls each receive their own array(0.0f) constant. After
  // simplify() both are expected to refer to the same input array.
  auto values = array({-1.0f, 2.0f});
  auto positive_part = maximum(values, array(0.0f));
  auto negative_part = maximum(-values, array(0.0f));
  auto total = positive_part + negative_part;
  simplify({total});
  CHECK(positive_part.inputs()[1].id() == negative_part.inputs()[1].id());
}
|
||||||
|
|
||||||
|
TEST_CASE("test simplify") {
  // exp(values) appears twice; after simplify() both inputs of the sum are
  // expected to be the same array.
  auto values = array({1.0f, 2.0f});
  auto total = exp(values) + exp(values);
  simplify(total);
  eval(total);
  CHECK(total.inputs()[0].id() == total.inputs()[1].id());
}
|
||||||
|
|
||||||
|
TEST_CASE("test no simplify") {
  // cos and sin are different operations on the same input, so the two
  // inputs of the sum must stay distinct after simplify().
  auto values = array({1.0f, 2.0f});
  auto total = cos(values) + sin(values);
  simplify(total);
  eval(total);
  CHECK(total.inputs()[0].id() != total.inputs()[1].id());
}
|
Loading…
Reference in New Issue
Block a user