Compare commits

..

4 Commits

Author SHA1 Message Date
Awni Hannun
0dbe80a024 try again with checkpointed classes 2024-03-06 10:38:04 -08:00
Awni Hannun
a5827d0384 docs for checkpoint + a few more tests 2024-03-06 10:38:04 -08:00
Awni Hannun
1368bce280 fix tests and add setter attributes 2024-03-06 10:38:04 -08:00
Awni Hannun
8918a437bb checkpoint module's __call__ 2024-03-06 10:38:04 -08:00
285 changed files with 9765 additions and 26541 deletions

View File

@@ -31,7 +31,8 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -43,12 +44,16 @@ jobs:
- run:
name: Generate package stubs
command: |
echo "stubs"
python setup.py generate_stubs
python3 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
@@ -58,24 +63,21 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.8
python3.8 -m venv env
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install torch
pip install tensorflow
@@ -89,17 +91,18 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
- run:
name: Build example extension
command: |
cd examples/extensions && python3.8 -m pip install .
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
- store_test_results:
path: test-results
- run:
@@ -126,7 +129,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
@@ -137,8 +140,9 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
@@ -153,7 +157,7 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
python setup.py generate_stubs
- run:
name: Build Python package
command: |
@@ -201,8 +205,9 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install --upgrade pybind11[global]
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install auditwheel
pip install patchelf
@@ -210,7 +215,7 @@ jobs:
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
python setup.py generate_stubs
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
@@ -230,10 +235,7 @@ workflows:
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- mac_build_and_test
- linux_build_and_test
build_pypi_release:
@@ -252,7 +254,7 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -266,9 +268,6 @@ workflows:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -281,7 +280,7 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["14.3.1", "15.2.0"]
weekly_build:
when:
and:
@@ -292,7 +291,7 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["14.3.1", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:

View File

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
rev: v17.0.6
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
@@ -15,8 +15,6 @@ MLX was developed with contributions from the following individuals:
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@@ -15,36 +15,31 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.13.1)
set(MLX_VERSION 0.5.1)
endif()
# --------------------- Processor tests -------------------------
message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC)
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
else()
message(WARNING "Building for x86_64 arch is not officially supported.")
endif()
set(MLX_BUILD_METAL OFF)
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
@@ -69,14 +64,8 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION
@@ -85,21 +74,18 @@ elseif (MLX_BUILD_METAL)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
set(MLX_METAL_VERSION METAL_3_1)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
set(MLX_METAL_VERSION METAL_3_0)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)
@@ -113,57 +99,42 @@ elseif (MLX_BUILD_METAL)
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION})
endif()
if (MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
#set(BLA_VENDOR Generic)
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -177,12 +148,8 @@ target_include_directories(
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

View File

@@ -17,13 +17,14 @@
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
<< std::flush << std::setprecision(5) \
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
#define TIMEM(MSG, FUNC, ...) \
std::cout << "Timing " \
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
<< std::endl;
template <typename F, typename... Args>
double time_fn(F fn, Args&&... args) {
double time_fn(F fn, Args... args) {
// warmup
for (int i = 0; i < 5; ++i) {
eval(fn(std::forward<Args>(args)...));

View File

@@ -1,123 +0,0 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -1,57 +0,0 @@
# Copyright © 2024 Apple Inc.
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def bandwidth_gb(runtime_ms, system_size):
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
bytes_per_gb = 1e9
ms_per_s = 1e3
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
return bandwidths
def time_fft():
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
if __name__ == "__main__":
time_fft()

View File

@@ -1,41 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def layer_norm(x, w, b, eps):
ot = x.dtype
x = x.astype(mx.float32)
mu = mx.mean(x, -1, keepdims=True)
v = mx.var(x, -1, keepdims=True)
return (x - mu) * mx.rsqrt(v + eps) * w + b
def time_layer_norm():
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, b, y)
def layer_norm_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
time_fn(layer_norm_loop, g1, x, w, b)
time_fn(layer_norm_loop, g2, x, w, b)
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
if __name__ == "__main__":
time_layer_norm()

View File

@@ -1,39 +0,0 @@
# Copyright © 2023-2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
def rms_norm(x, w, eps):
ot = x.dtype
x = x.astype(mx.float32)
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (x * n).astype(ot) * w
def time_rms_norm():
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1))
g2 = mx.grad(f2, argnums=(0, 1))
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
mx.eval(x, w, y)
def rms_norm_loop(g, x, w):
gx, gw = x, w
for _ in range(32):
gx, gw = g(gx, gw, y)
return gx, gw
time_fn(rms_norm_loop, g1, x, w)
time_fn(rms_norm_loop, g2, x, w)
time_fn(rms_norm_loop, mx.compile(g1), x, w)
time_fn(rms_norm_loop, mx.compile(g2), x, w)
if __name__ == "__main__":
time_rms_norm()

View File

@@ -6,21 +6,21 @@ from time_utils import time_fn
def time_rope():
rope = nn.RoPE(64)
rope = nn.RoPE(4096)
# vec
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
mx.eval(x)
def rope_vec(x):
for _ in range(32):
x = rope(x, offset=100)
x = rope(x)
return x
time_fn(rope_vec, x)
# matrix
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
mx.eval(x)
def rope_mat(x):

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,36 +0,0 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,50 +0,0 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,16 +2,12 @@
### Setup (do once)
Install Doxygen:
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
```
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
conda install sphinx
pip install sphinx-book-theme
```
### Build
@@ -19,7 +15,7 @@ pip install -r requirements.txt
Build the docs from `mlx/docs/`
```
doxygen && make html
make html
```
View the docs by running a server in `mlx/docs/build/html/`:

View File

@@ -1,3 +0,0 @@
sphinx
breathe
sphinx-book-theme

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 746 KiB

View File

@@ -1,20 +0,0 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != "__init__" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}

View File

@@ -22,7 +22,6 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
@@ -30,20 +29,16 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"
main_doc = "index"
master_doc = "index"
highlight_language = "python"
pygments_style = "sphinx"
add_module_names = False
# -- Options for HTML output -------------------------------------------------
@@ -64,22 +59,3 @@ html_theme_options = {
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
from sphinx.util import inspect
wrapped_isfunc = inspect.isfunction
def isfunc(obj):
type_name = str(type(obj))
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
return True
return wrapped_isfunc(obj)
inspect.isfunction = isfunc
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]

View File

@@ -3,5 +3,4 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -1,16 +1,24 @@
Custom Extensions in MLX
========================
Developer Documentation
=======================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
MLX provides a open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner-workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
Introducing the Example
-----------------------
Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
.. code-block:: python
@@ -19,35 +27,44 @@ You can do that in MLX directly:
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementation and
function transformations to MLX.
This function performs that operation while leaving the implementations and
differentiation to MLX.
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions on to you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
* The structure of the MLX library.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
Operations and Primitives
-------------------------
Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations
^^^^^^^^^^^
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
.. code-block:: C++
@@ -66,7 +83,10 @@ C++:
StreamOrDevice s = {} // Stream on which to schedule the operation
);
The simplest way to this operation is in terms of existing operations:
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be do so in terms
of existing operations.
.. code-block:: C++
@@ -80,23 +100,25 @@ The simplest way to this operation is in terms of existing operations:
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy to use interface that use :class:`Primitive` building blocks.
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy to use interface that build
on top of the building blocks we call :class:`Primitive`.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete:
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` . Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so lets take a step
back and go to our example to give ourselves a more concrete image.
.. code-block:: C++
@@ -112,15 +134,11 @@ more concrete:
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) override;
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -129,8 +147,7 @@ more concrete:
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
@@ -138,7 +155,7 @@ more concrete:
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -158,22 +175,22 @@ more concrete:
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
Using the Primitive
^^^^^^^^^^^^^^^^^^^
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
@@ -206,7 +223,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -221,26 +238,27 @@ This operation now handles the following:
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. The operation only
builds the computation graph. When we evaluate the output array, MLX schedules
the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
of these functions to allocate memory as needed
Implementing the CPU Back-end
Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
@@ -278,19 +296,19 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
}
}
Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
``complex64``. We throw an error if we encounter an unexpected type.
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -303,26 +321,28 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"[Axpby] Only supports floating point types.");
"Axpby is only supported for floating point types.");
}
}
This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
elements have fixed strides between them. We only direct to Accelerate
if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
MLX expects to write the output to a new array. We must copy the elements
of ``y`` into the output and use that as an input to ``axpby``.
floats. We can only direct to it for ``float32`` types
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``
Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call the :meth:`catlas_saxpby` from accelerate.
.. code-block:: C++
@@ -336,7 +356,17 @@ It allocates data for the output, copies ``y`` into it, and then calls the
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -359,20 +389,18 @@ It allocates data for the output, copies ``y`` into it, and then calls the
/* INCY = */ 1);
}
For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.
Great! But what about the inputs that do not fit the criteria for accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.
With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -382,33 +410,35 @@ For inputs that do not fit the criteria for accelerate, we fall back to
return;
}
// Fall back to common back-end if specializations are not available
eval(inputs, outputs);
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
Implementing the GPU Back-end
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
GPU kernels in MLX are written using Metal.
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.
.. note::
Here are some helpful resources if you are new to Metal:
Here are some helpful resources if you are new to metal!
* A walkthrough of the metal compute pipeline: `Metal Example`_
* Documentation for metal shading language: `Metal Specification`_
* Using metal from C++: `Metal-cpp`_
Let's keep the GPU kernel simple. We will launch exactly as many threads as
there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the point-wise operation, and update its assigned
element in the output.
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
.. code-block:: C++
@@ -427,14 +457,15 @@ element in the output.
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify it.
each instantiation a unique host name so we can identify the right kernel for
each data type.
.. code-block:: C++
@@ -457,21 +488,29 @@ each instantiation a unique host name so we can identify it.
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/ shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why,
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -479,10 +518,10 @@ below.
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
@@ -494,7 +533,7 @@ below.
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -503,17 +542,17 @@ below.
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
@@ -531,30 +570,33 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined now. Which then gives
us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
.. code-block:: C++
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -569,12 +611,12 @@ one we just defined:
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
@@ -583,35 +625,34 @@ one we just defined:
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>& /* unused */) {
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("[Axpby] vmap not implemented.");
throw std::runtime_error("Axpby has no vmap implementation.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
Let's look at the overall directory structure first.
| extensions
| ├── axpby
@@ -625,39 +666,40 @@ Let's look at the overall directory structure first.
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
Python bindings
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the Python package
the python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use nanobind_ to build a Python API for the C++ library. Since bindings for
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.
already provided, adding our :meth:`axpby` is simple!
.. code-block:: C++
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -669,17 +711,17 @@ already provided, adding our :meth:`axpby` is simple.
Returns:
array: ``alpha * x + beta * y``
)");
)pbdoc");
}
Most of the complexity in the above example comes from additional bells and
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` must be imported before importing
:mod:`mlx_sample_extensions` as defined by the nanobind module above to
ensure that the casters for :mod:`mlx.core` components like
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
@@ -687,8 +729,8 @@ whistles such as the literal names and doc-strings.
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.
Building the C++ extension library itself is simple, it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
.. code-block:: cmake
@@ -710,12 +752,12 @@ CONFIG)`` and then link it to your library.
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with MLX package).
Here is what that looks like in practice:
Here is what that looks like in practice!
.. code-block:: cmake
@@ -737,29 +779,27 @@ Here is what that looks like in practice:
endif()
Finally, we build the nanobind_ bindings
Finally, we build the Pybind11_ bindings
.. code-block:: cmake
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension`:
build utilities defined in :mod:`mlx.extension` for a simple build process.
.. code-block:: python
.. code-block:: python
from mlx import extension
from setuptools import setup
@@ -769,50 +809,48 @@ build utilities defined in :mod:`mlx.extension`:
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev":[]},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.8",
python_requires=">=3.7",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
* :mod:`mlx.core` must be imported before importing :mod:`_ext`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)
This results in the directory structure:
This will result in a directory structure as follows:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── _ext.cpython-3x-darwin.so # Python Binding
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
Let's look at a simple script and its results:
Let's looks at a simple script and it's results!
.. code-block:: python
@@ -825,7 +863,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
Output:
@@ -836,12 +874,12 @@ Output:
c correctness: True
Results
^^^^^^^
^^^^^^^^^^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.
.. code-block:: python
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
@@ -860,7 +898,7 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
alpha = 4.0
beta = 2.0
mx.eval(x, y)
mx.eval((x, y))
def bench(f):
# Warm up
@@ -881,23 +919,30 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!
Results:
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
:meth:`grad`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
The full example code is available in `mlx <code>`_.
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@@ -1,68 +0,0 @@
Metal Debugger
==============
.. currentmodule:: mlx.core
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.
.. note::
To capture a GPU trace you must run the application with
``MTL_CAPTURE_ENABLED=1``.
.. code-block:: python
import mlx.core as mx
a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))
mx.eval(a, b)
trace_file = "mlx_trace.gputrace"
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
# that the path trace_file does not already exist.
mx.metal.start_capture(trace_file)
for _ in range(10):
mx.eval(mx.add(a, b))
mx.metal.stop_capture()
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Checkout the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example schema and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@@ -58,12 +58,10 @@ are the CPU and GPU.
:maxdepth: 1
python/array
python/data_types
python/devices_and_streams
python/ops
python/random
python/transforms
python/fast
python/fft
python/linalg
python/metal
@@ -82,4 +80,3 @@ are the CPU and GPU.
:maxdepth: 1
dev/extensions
dev/metal_debugger

View File

@@ -15,10 +15,10 @@ To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- macOS >= 13.5
- macOS >= 13.3
.. note::
MLX is only available on devices running macOS >= 13.5
MLX is only available on devices running macOS >= 13.3
It is highly recommended to use macOS 14 (Sonoma)
@@ -54,7 +54,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@@ -70,13 +70,16 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
Then simply build and install MLX using pip:
Then simply build and install it using pip:
.. code-block:: shell
@@ -120,7 +123,7 @@ Create a build directory and run CMake and make:
.. code-block:: shell
mkdir -p build && cd build
cmake .. && make -j
cmake .. && make -j
Run tests with:
@@ -139,7 +142,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.
.. list-table:: Build Options
.. list-table:: Build Options
:widths: 25 8
:header-rows: 1
@@ -153,56 +156,31 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
.. note::
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
If you have multiple Xcode installations and wish to use
a specific one while building, you can do so by adding the
following environment variable before building
.. code-block:: shell
export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"
Further, you can use the following command to find out which
Further, you can use the following command to find out which
macOS SDK will be used
.. code-block:: shell
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
```shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=ON \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF
```
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~
@@ -224,7 +202,7 @@ Then set the active developer directory:
sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
x86 Shell
x86 Shell
~~~~~~~~~
.. _build shell:

View File

@@ -10,38 +10,27 @@ Array
array
array.astype
array.at
array.item
array.tolist
array.dtype
array.itemsize
array.nbytes
array.ndim
array.shape
array.size
Dtype
array.abs
array.all
array.any
array.argmax
array.argmin
array.cos
array.cummax
array.cummin
array.cumprod
array.cumsum
array.diag
array.diagonal
array.dtype
array.exp
array.flatten
array.log
array.log10
array.log1p
array.log2
array.logsumexp
array.max
array.mean
array.min
array.moveaxis
array.prod
array.reciprocal
array.reshape
@@ -51,8 +40,6 @@ Array
array.split
array.sqrt
array.square
array.squeeze
array.swapaxes
array.sum
array.transpose
array.T

View File

@@ -1,5 +1,7 @@
.. _data_types:
:orphan:
Data Types
==========
@@ -42,27 +44,9 @@ The default floating point type is ``float32`` and the default integer type is
* - ``int64``
- 8
- 64-bit signed integer
* - ``bfloat16``
- 2
- 16-bit brain float (e8, m7)
* - ``float16``
- 2
- 16-bit IEEE float (e5, m10)
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
* - ``complex64``
- 8
- 64-bit complex float
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
.. autosummary::
:toctree: _autosummary
Dtype
DtypeCategory
issubdtype

View File

@@ -16,4 +16,3 @@ Devices and Streams
new_stream
set_default_stream
stream
synchronize

View File

@@ -1,14 +0,0 @@
.. _fast:
Fast
====
.. currentmodule:: mlx.core.fast
.. autosummary::
:toctree: _autosummary
rms_norm
layer_norm
rope
scaled_dot_product_attention

View File

@@ -8,7 +8,5 @@ Linear Algebra
.. autosummary::
:toctree: _autosummary
inv
norm
qr
svd

View File

@@ -3,17 +3,12 @@ Metal
.. currentmodule:: mlx.core.metal
.. autosummary::
.. autosummary::
:toctree: _autosummary
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
clear_cache
start_capture
stop_capture

View File

@@ -173,7 +173,7 @@ In detail:
:toctree: _autosummary
value_and_grad
quantize
checkpoint
.. toctree::

View File

@@ -15,28 +15,23 @@ Layers
BatchNorm
Conv1d
Conv2d
Conv3d
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GroupNorm
GRU
InstanceNorm
LayerNorm
Linear
LSTM
MaxPool1d
MaxPool2d
Mish
MultiHeadAttention
PReLU
QuantizedEmbedding
QuantizedLinear
RMSNorm
ReLU
RNN
RoPE
SELU
Sequential
@@ -45,4 +40,4 @@ Layers
Softshrink
Step
Transformer
Upsample
Upsample

View File

@@ -30,7 +30,6 @@ Module
Module.named_modules
Module.parameters
Module.save_weights
Module.set_dtype
Module.train
Module.trainable_parameters
Module.unfreeze

View File

@@ -5,14 +5,13 @@ Operations
.. currentmodule:: mlx.core
.. autosummary::
.. autosummary::
:toctree: _autosummary
abs
add
addmm
all
allclose
allclose
any
arange
arccos
@@ -20,39 +19,25 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
concatenate
conj
conjugate
convolve
conv1d
conv2d
conv_general
cos
cosh
cummax
cummin
cumprod
cumsum
degrees
dequantize
diag
diagonal
@@ -62,7 +47,6 @@ Operations
erf
erfinv
exp
expm1
expand_dims
eye
flatten
@@ -74,12 +58,10 @@ Operations
identity
inner
isclose
isinf
isnan
isneginf
isposinf
issubdtype
left_shift
isneginf
isinf
less
less_equal
linspace
@@ -97,28 +79,22 @@ Operations
max
maximum
mean
meshgrid
min
minimum
moveaxis
multiply
negative
not_equal
ones
ones_like
outer
partition
pad
power
prod
quantize
quantized_matmul
radians
reciprocal
remainder
repeat
reshape
right_shift
round
rsqrt
save
@@ -137,7 +113,6 @@ Operations
square
squeeze
stack
std
stop_gradient
subtract
sum

View File

@@ -1,7 +1,5 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -36,8 +34,3 @@ model's parameters and the **optimizer state**.
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm

View File

@@ -38,7 +38,6 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
gumbel
key
normal
multivariate_normal
randint
seed
split

View File

@@ -17,3 +17,4 @@ Transforms
jvp
vjp
vmap
checkpoint

View File

@@ -19,5 +19,3 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_flatten
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

View File

@@ -40,7 +40,7 @@ getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.

View File

@@ -18,7 +18,7 @@ describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lazy evaluation lets us record a compute graph without actually doing any
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.

View File

@@ -49,7 +49,7 @@ it will be added. You can load the array with:
.. code-block:: shell
>>> mx.load("array.npy")
>>> mx.load("array.npy", a)
array([1], dtype=float32)
Here's an example of saving several arrays to a single file:

View File

@@ -8,4 +8,3 @@ endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)

View File

@@ -1,31 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
// To use Metal debugging and profiling:
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
// 2. Run with MTL_CAPTURE_ENABLED=1.
metal::start_capture("mlx_trace.gputrace");
// Start at index two because the default GPU and CPU streams have indices
// zero and one, respectively. This naming matches the label assigned to each
// stream's command queue.
auto s2 = new_stream(Device::gpu);
auto s3 = new_stream(Device::gpu);
auto a = arange(1.f, 10.f, 1.f, float32, s2);
auto b = arange(1.f, 10.f, 1.f, float32, s3);
auto x = add(a, a, s2);
auto y = add(b, b, s3);
// The multiply will happen on the default stream.
std::cout << multiply(x, y) << std::endl;
metal::stop_capture();
}

View File

@@ -89,8 +89,8 @@ void automatic_differentiation() {
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto d2fdx2 = grad(grad(fn))(x);
// d2fdx2 is 2
auto df2dx2 = grad(grad(fn))(x);
// df2dx2 is 2
}
int main() {

View File

@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.27)
project(_ext LANGUAGES CXX)
project(mlx_sample_extensions LANGUAGES CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
@@ -11,12 +11,8 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
# ----------------------------- Extensions -----------------------------
@@ -42,6 +38,7 @@ target_link_libraries(mlx_ext PUBLIC mlx)
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
@@ -57,15 +54,13 @@ if(MLX_BUILD_METAL)
endif()
# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
_ext
NB_STATIC STABLE_ABI LTO NOMINSIZE
NB_DOMAIN mlx
# ----------------------------- Pybind -----------------------------
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(_ext PRIVATE mlx_ext)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()

View File

@@ -1,24 +0,0 @@
## Build
```
pip install -e .
```
For faster builds during development, you can also pre-install the requirements:
```
pip install -r requirements.txt
```
And then run:
```
python setup.py build_ext -j8 --inplace
```
## Test
```
python test.py
`

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
@@ -43,7 +43,7 @@ array axpby(
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = issubdtype(promoted_dtype, float32)
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
@@ -61,7 +61,7 @@ array axpby(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta),
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
@@ -106,12 +106,12 @@ void axpby_impl(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& out_arr) {
auto out = out_arr[0];
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
@@ -150,7 +150,11 @@ void axpby_impl_accelerate(
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
@@ -176,11 +180,11 @@ void axpby_impl_accelerate(
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& outarr) {
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
@@ -191,7 +195,7 @@ void Axpby::eval_cpu(
}
// Fall back to common backend if specializations are not available
eval(inputs, outputs);
eval(inputs, outarr);
}
#else // Accelerate not available
@@ -199,8 +203,8 @@ void Axpby::eval_cpu(
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
eval(inputs, outputs);
std::vector<array>& out) {
eval(inputs, out);
}
#endif
@@ -214,12 +218,12 @@ void Axpby::eval_cpu(
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
std::vector<array>& outarr) {
// Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
auto& out = outputs[0];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
@@ -257,7 +261,7 @@ void Axpby::eval_gpu(
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -266,11 +270,11 @@ void Axpby::eval_gpu(
size_t nelem = out.size();
// Encode input arrays to kernel
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -296,7 +300,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available
@@ -368,4 +372,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta) {};
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
/** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
void eval(const std::vector<array>& inputs, std::vector<array>& out);
};
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -19,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
@@ -31,30 +31,30 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);

View File

@@ -1,31 +1,31 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "axpby/axpby.h"
namespace nb = nanobind;
using namespace nb::literals;
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
NB_MODULE(_ext, m) {
m.doc() = "Sample extension for MLX";
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
nb::kw_only(),
"stream"_a = nb::none(),
R"(
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors element-wise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
@@ -37,5 +37,5 @@ NB_MODULE(_ext, m) {
Returns:
array: ``alpha * x + beta * y``
)");
}
)pbdoc");
}

View File

@@ -2,4 +2,4 @@
import mlx.core as mx
from ._ext import axpby
from .mlx_sample_extensions import *

View File

@@ -1,8 +1,3 @@
[build-system]
requires = [
"setuptools>=42",
"cmake>=3.24",
"mlx>=0.9.0",
"nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"

View File

@@ -1,4 +0,0 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

View File

@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
# Copyright © 2023 Apple Inc.
from setuptools import setup
@@ -9,11 +9,11 @@ if __name__ == "__main__":
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages=["mlx_sample_extensions"],
package_dir={"": "."},
package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
extras_require={"dev": []},
zip_safe=False,
python_requires=">=3.8",
)

View File

@@ -1,10 +0,0 @@
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -19,16 +19,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
if (MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
else()
target_sources(
mlx
PRIVATE

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr) {};
Buffer(void* ptr) : ptr_(ptr){};
// Get the raw data pointer from the buffer
void* raw_ptr();

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include "mlx/array.h"
@@ -11,6 +12,16 @@ namespace mlx::core {
namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
@@ -25,11 +36,22 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
init(&cval);
}
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
dtype,
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
@@ -37,16 +59,15 @@ array::array(
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
for (int i = 0; i < shapes.size(); ++i) {
outputs.push_back(array(shapes[i], dtypes[i], primitive, inputs));
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -71,10 +92,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
set_data(data, deleter);
}
@@ -83,22 +104,18 @@ void array::detach() {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
void array::eval() {
// Ensure the array is ready to be read
if (status() == Status::scheduled) {
event().wait();
set_status(Status::available);
} else if (status() == Status::unscheduled) {
mlx::core::eval({*this});
}
mlx::core::eval({*this});
}
bool array::is_tracer() const {
@@ -147,116 +164,51 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::~array() {
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
if (!do_detach) {
break;
}
}
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
init();
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
}
array::ArrayDesc::ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs)
: shape(std::move(shape)),
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
init();
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destory their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we
// instead put them in a vector and destroy them one at a time resulting in a
// max stack depth of 2.
std::vector<std::shared_ptr<ArrayDesc>> for_deletion;
for (array& a : inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
while (!for_deletion.empty()) {
// top is going to be deleted at the end of the block *after* the arrays
// with inputs have been moved into the vector
auto top = std::move(for_deletion.back());
for_deletion.pop_back();
for (array& a : top->inputs) {
if (a.array_desc_.use_count() == 1) {
for_deletion.push_back(std::move(a.array_desc_));
}
}
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
@@ -9,7 +8,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
namespace mlx::core {
@@ -33,7 +31,7 @@ class array {
template <typename It>
array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -49,13 +47,13 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
deleter_t deleter = allocator::free);
@@ -114,15 +112,6 @@ class array {
return array_desc_->strides;
};
/**
* Get the stride of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
};
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
@@ -183,16 +172,22 @@ class array {
* API may change.
*/
array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
std::vector<array>&& inputs);
static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes,
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
/** A unique identifier for an array. */
@@ -209,7 +204,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {};
: buffer(buffer), d(d){};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -261,11 +256,6 @@ class array {
return array_desc_->siblings;
};
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
@@ -283,6 +273,11 @@ class array {
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
void detach();
@@ -319,27 +314,9 @@ class array {
return static_cast<T*>(array_desc_->data_ptr);
};
enum Status { unscheduled, scheduled, available };
bool is_available() const {
return status() == Status::available;
}
const Status status() const {
return array_desc_->status;
}
void set_status(Status s) const {
array_desc_->status = s;
}
// Get the array's shared event
Event& event() const {
return array_desc_->event;
}
// Attach an event to a not yet evaluated array
void attach_event(Event e) const {
array_desc_->event = std::move(e);
// Check if the array has been evaluated
bool is_evaled() const {
return array_desc_->data != nullptr;
}
// Mark the array as a tracer array (true) or not.
@@ -367,21 +344,12 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
~array();
private:
// Initialize the arrays data
template <typename It>
@@ -392,12 +360,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive;
Status status;
// An event on the array used for synchronization
Event event;
std::shared_ptr<Primitive> primitive{nullptr};
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -405,7 +368,7 @@ class array {
// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data;
std::shared_ptr<Data> data{nullptr};
// Properly offset data pointer
void* data_ptr{nullptr};
@@ -425,26 +388,29 @@ class array {
// The arrays position in the output list
uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
const std::vector<array>& inputs);
~ArrayDesc();
private:
// Initialize size, strides, and other metadata
void init();
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_;
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};
template <typename T>
@@ -456,9 +422,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
init(data);
}
@@ -475,9 +441,9 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
std::vector<int> shape,
const std::vector<int>& shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
@@ -499,11 +465,10 @@ T array::item() const {
if (size() != 1) {
throw std::invalid_argument("item can only be called on arrays of size 1.");
}
if (status() == Status::unscheduled) {
if (!is_evaled()) {
throw std::invalid_argument(
"item() const can only be called on evaled arrays");
}
const_cast<array*>(this)->eval();
return *data<T>();
}
@@ -553,15 +518,4 @@ void array::init(It src) {
}
}
/* Utilities for determining whether a template parameter is array. */
template <typename T>
inline constexpr bool is_array_v =
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
template <typename... T>
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
@@ -196,40 +196,6 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
return matmul_bnns_general(a_pre, b_pre, out);
}
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {

View File

@@ -31,18 +31,13 @@ DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@@ -73,13 +68,10 @@ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -195,26 +187,6 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -325,7 +297,7 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
@@ -334,19 +306,6 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -392,7 +351,7 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(

View File

@@ -10,65 +10,78 @@
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
};
}
} // namespace
// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
}
}
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
VT _sum = {0};
while (size >= N) {
_sum += (*(VT*)x);
x += N;
size -= N;
}
T sum = _sum[0];
for (int i = 1; i < N; i++) {
sum += _sum[i];
}
*accum += sum;
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -81,11 +94,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_sum<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
@@ -99,11 +111,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_max<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
@@ -117,11 +128,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_min<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
template <typename T, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -218,21 +218,13 @@ void softmax(const array& in, array& out) {
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
T maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
maximum = std::max(maximum, *current_in_ptr);
current_in_ptr++;
}
@@ -242,29 +234,18 @@ void softmax(const array& in, array& out) {
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
ops.store(current_out_ptr, vexp);
*(VT*)current_out_ptr = vexp;
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
T normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
T _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = _exp;
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
@@ -273,33 +254,14 @@ void softmax(const array& in, array& out) {
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
*current_out_ptr *= normalizer;
current_out_ptr++;
}
}
@@ -346,29 +308,15 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
}
softmax<
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
break;
case bfloat16:
eval(inputs, out);

View File

@@ -37,15 +37,14 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -54,8 +53,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

View File

@@ -179,16 +179,18 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
if (is_floating_point(out.dtype())) {
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
}
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
@@ -236,82 +238,4 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
}
} // namespace mlx::core

View File

@@ -1,347 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -82,27 +81,13 @@ std::string build_lib_name(
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
@@ -126,102 +111,4 @@ std::string build_lib_name(
return os.str();
}
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (const auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
} else if (non_scalar_inputs == 0 && !shape.empty()) {
contiguous = false;
}
return contiguous;
}
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
// - Not a scalar
// - Donatable
// - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o++].move_shared_buffer(in);
} else {
outputs[o++].copy_shared_buffer(in);
}
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
} else {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
}
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
}
} // namespace mlx::core

View File

@@ -53,18 +53,4 @@ inline bool is_scalar(const array& x) {
return x.ndim() == 0;
}
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::vector<array>& inputs_,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous,
bool move_buffers = false);
} // namespace mlx::core

View File

@@ -52,25 +52,8 @@ void* compile(
return nullptr;
}
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
// characters. Clip file name with a little extra room and append a 16
// character hash.
constexpr int max_file_name_length = 245;
if (kernel_name.size() > max_file_name_length) {
std::ostringstream file_name;
file_name
<< std::string_view(kernel_name).substr(0, max_file_name_length - 16);
auto file_id = std::hash<std::string>{}(kernel_name);
file_name << "_" << std::hex << std::setw(16) << file_id << std::dec;
kernel_file_name = file_name.str();
} else {
kernel_file_name = kernel_name;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_file_name << ".so";
shared_lib_name << "lib" << kernel_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
@@ -81,7 +64,7 @@ void* compile(
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_file_name << ".cpp";
source_file_name << kernel_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
@@ -265,7 +248,28 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
bool contiguous = true;
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -338,8 +342,56 @@ void Compiled::eval_cpu(
fn_ptr = compile(kernel_name, kernel.str());
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
// Allocate space for the outputs possibly with input donation
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
for (auto& x : outputs) {
args.push_back(x.data<void>());

View File

@@ -38,15 +38,11 @@ void slow_conv_1D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2];
@@ -61,36 +57,35 @@ void slow_conv_1D(
for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) {
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int o = 0; o < O; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
auto ih_div = std::div(ih, in_dilation[0]);
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
} // c
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
} // ih check
} // wh
} // ih check
} // wh
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // g
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // oh
in_ptr += in_stride_N;
out_ptr += out_stride_N;
} // n
}
@@ -310,296 +305,6 @@ void slow_conv_2D(
} // n
}
template <typename T>
void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
const int oD = out.shape(1); // Output spatial dim
const int oH = out.shape(2); // Output spatial dim
const int oW = out.shape(3); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(4); // In channels
const int wD = wt.shape(1); // Weight spatial dim
const int wH = wt.shape(2); // Weight spatial dim
const int wW = wt.shape(3); // Weight spatial dim
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_D = in.strides()[1];
const size_t in_stride_H = in.strides()[2];
const size_t in_stride_W = in.strides()[3];
const size_t in_stride_C = in.strides()[4];
const size_t wt_stride_O = wt.strides()[0];
const size_t wt_stride_D = wt.strides()[1];
const size_t wt_stride_H = wt.strides()[2];
const size_t wt_stride_W = wt.strides()[3];
const size_t wt_stride_C = wt.strides()[4];
const size_t out_stride_N = out.strides()[0];
const size_t out_stride_D = out.strides()[1];
const size_t out_stride_H = out.strides()[2];
const size_t out_stride_W = out.strides()[3];
const size_t out_stride_O = out.strides()[4];
bool is_idil_one =
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = 0; wd < wD; ++wd) {
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
const T* wt_ptr_pt =
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
std::vector<int> base_d(f_out_jump_d);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
wd_base++;
id_loop += jump_d;
}
base_d[i] = wd_base;
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
iw < iW) {
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // iD, ih, iw check
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int oD_border_0 = 0;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: od might put us out of bounds
for (int od = oD_border_0; od < oD_border_1; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 2: od in bounds
for (int od = oD_border_1; od < oD_border_2; ++od) {
// Case 2.1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case 2.2.1: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.2: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.3: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 3: od might put us out of bounds
for (int od = oD_border_2; od < oD_border_3; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
}
void dispatch_slow_conv_1D(
const array& in,
const array& wt,
@@ -648,30 +353,6 @@ void dispatch_slow_conv_2D(
}
}
void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_3D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_3D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_3D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
@@ -685,15 +366,11 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
auto conv_dtype = float32;
// Pad input
@@ -725,11 +402,6 @@ void explicit_gemm_conv_1D_cpu(
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
std::swap(strided_shape[2], strided_shape[3]);
std::swap(strided_strides[2], strided_strides[3]);
}
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
@@ -744,19 +416,7 @@ void explicit_gemm_conv_1D_cpu(
auto gemm_wt = wt;
auto gemm_out = out;
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
array wt_transpose(
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(
wt,
{wt.strides(0), wt.strides(2), wt.strides(1)},
wt.flags(),
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -768,29 +428,27 @@ void explicit_gemm_conv_1D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
for (int g = 0; g < groups; ++g) {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O_per_group, // N
C_per_group * wH, // K
1.0f, // alpha
in_strided.data<float>() + g * C_per_group * wH, // A
wH * C, // lda
gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
wH * C_per_group, // ldb
0.0f, // beta
gemm_out.data<float>() + g * O_per_group, // C
O // ldc
);
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
@@ -896,131 +554,6 @@ void explicit_gemm_conv_2D_cpu(
}
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
}
for (size_t i = 0; i < wDim.size(); i++) {
strided_shape[i + 1 + oDim.size()] = wDim[i];
}
strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
}
for (size_t i = 1; i < in_padded.strides().size(); i++) {
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
}
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}
for (const auto& w : wDim) {
strided_reshape[1] *= w;
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
@@ -1056,19 +589,6 @@ void conv_2D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
void Convolution::eval(const std::vector<array>& inputs, array& out) {
@@ -1077,20 +597,8 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto& wt = inputs[1];
// 3D convolution
if (in.ndim() == (3 + 2)) {
return conv_3D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 2D convolution
else if (in.ndim() == (2 + 2)) {
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <numeric>
@@ -25,196 +25,121 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim1(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
template <typename SrcT, typename DstT>
void copy_general_dim1(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[0];
src_idx += src.strides()[0];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general_dim2(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
src_idx += src.strides()[1];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general_dim3(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
src_idx += src.strides()[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general_dim4(const array& src, array& dst) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
src_idx += src.strides()[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
void copy_general(const array& src, array& dst) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim1<SrcT, DstT>(src, dst);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim2<SrcT, DstT>(src, dst);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim3<SrcT, DstT>(src, dst);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
copy_general_dim4<SrcT, DstT>(src, dst);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
template <typename SrcT, typename DstT, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
size_t offset_src,
size_t offset_dst) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
copy_general_general_dims<SrcT, DstT, D - 1>(
src, dst, offset_src, offset_dst);
offset_src += stride_src;
offset_dst += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
@@ -223,56 +148,37 @@ inline void copy_general_general_dims(
}
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
template <typename SrcT, typename DstT>
void copy_general_general(const array& src, array& dst) {
switch (src.ndim()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
return;
}
int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
}
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
void copy(const array& src, array& dst, CopyType ctype) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
@@ -281,103 +187,54 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general<SrcT, DstT>(src, dst);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general_general<SrcT, DstT>(src, dst);
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
template <typename SrcT>
void copy(const array& src, array& dst, CopyType ctype) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bool>(src, dst, ctype);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint8_t>(src, dst, ctype);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint16_t>(src, dst, ctype);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint32_t>(src, dst, ctype);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint64_t>(src, dst, ctype);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int8_t>(src, dst, ctype);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int16_t>(src, dst, ctype);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int32_t>(src, dst, ctype);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int64_t>(src, dst, ctype);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float16_t>(src, dst, ctype);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float>(src, dst, ctype);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
}
}
template <typename... Args>
inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, complex64_t>(src, dst, ctype);
break;
}
}
@@ -385,7 +242,47 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
return copy_inplace_dispatch(src, dst, ctype);
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype);
break;
case uint8:
copy<uint8_t>(src, dst, ctype);
break;
case uint16:
copy<uint16_t>(src, dst, ctype);
break;
case uint32:
copy<uint32_t>(src, dst, ctype);
break;
case uint64:
copy<uint64_t>(src, dst, ctype);
break;
case int8:
copy<int8_t>(src, dst, ctype);
break;
case int16:
copy<int16_t>(src, dst, ctype);
break;
case int32:
copy<int32_t>(src, dst, ctype);
break;
case int64:
copy<int64_t>(src, dst, ctype);
break;
case float16:
copy<float16_t>(src, dst, ctype);
break;
case float32:
copy<float>(src, dst, ctype);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<complex64_t>(src, dst, ctype);
break;
}
}
void copy(const array& src, array& dst, CopyType ctype) {
@@ -415,62 +312,4 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
template <>
void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -26,15 +26,4 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -34,7 +34,6 @@ DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
@@ -42,13 +41,9 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
@@ -56,13 +51,11 @@ DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
@@ -100,7 +93,6 @@ DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
@@ -108,11 +100,9 @@ DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
namespace {

View File

@@ -1,95 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
void inverse_impl(const array& a, array& inv) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
}
void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output);
}
} // namespace mlx::core

View File

@@ -1,23 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// This is to work around a change in the function signatures of lapack >= 3.9.1
// where functions taking char* also include a strlen argument, see a similar
// change in OpenCV:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
#define MLX_LAPACK_FUNC(f) f##_
#endif

View File

@@ -11,7 +11,7 @@ GCC=$2
SRCDIR=$3
CLANG=$4
if [ "$CLANG" = "TRUE" ]; then
if [ $CLANG = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>

View File

@@ -1,280 +0,0 @@
// Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(block_size, X - loc_x);
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& out_mask = inputs[2];
auto check_transpose = [](const array& arr, bool do_copy) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
auto mask_array = [](const array& mask,
float* data,
int block_size,
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str) {
const bool* mask_ptr = mask.data<bool>() +
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
return mask_matrix(
data,
mask_ptr,
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str);
};
for (int i = 0; i < (a.size() / (M * K)); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
float* bi =
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides());
float* ci = out.data<float>() + M * N * i;
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[3];
mask_array(
a_mask,
ai,
block_size_,
i,
M,
K,
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[4];
mask_array(
b_mask,
bi,
block_size_,
i,
K,
N,
b_transposed ? 1 : ldb,
b_transposed ? ldb : 1);
}
// Do matmul
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0, // alpha
ai,
lda,
bi,
ldb,
0.0, // beta
ci,
out.shape(-1) // ldc
);
// Zero out blocks in out
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core

View File

@@ -161,13 +161,6 @@ struct ArcTan {
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
@@ -209,12 +202,6 @@ struct Ceil {
};
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return std::conj(x);
}
};
struct Cos {
template <typename T>
T operator()(T x) {
@@ -254,13 +241,6 @@ struct Exp {
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
};
};
struct Floor {
template <typename T>
T operator()(T x) {
@@ -619,39 +599,4 @@ struct Select {
}
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
} // namespace mlx::core::detail

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -22,7 +22,7 @@ namespace mlx::core {
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), unsignedinteger)) {
if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
@@ -37,7 +37,7 @@ void Arange::eval(const std::vector<array>& inputs, array& out) {
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcCos());
} else {
throw std::invalid_argument(
@@ -49,7 +49,7 @@ void ArcCos::eval(const std::vector<array>& inputs, array& out) {
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcCosh());
} else {
throw std::invalid_argument(
@@ -61,7 +61,7 @@ void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcSin());
} else {
throw std::invalid_argument(
@@ -73,7 +73,7 @@ void ArcSin::eval(const std::vector<array>& inputs, array& out) {
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcSinh());
} else {
throw std::invalid_argument(
@@ -85,7 +85,7 @@ void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcTan());
} else {
throw std::invalid_argument(
@@ -97,7 +97,7 @@ void ArcTan::eval(const std::vector<array>& inputs, array& out) {
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::ArcTanh());
} else {
throw std::invalid_argument(
@@ -113,10 +113,65 @@ void AsType::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
if (not is_integral(in.dtype())) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
@@ -148,21 +203,15 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
}
}
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_fp(in, out, detail::Conjugate());
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
out.copy_shared_buffer(inputs[0]);
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Cos());
} else {
throw std::invalid_argument(
@@ -174,7 +223,7 @@ void Cos::eval(const std::vector<array>& inputs, array& out) {
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Cosh());
} else {
throw std::invalid_argument(
@@ -183,6 +232,25 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -226,7 +294,7 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Exp());
} else {
throw std::invalid_argument(
@@ -235,22 +303,10 @@ void Exp::eval(const std::vector<array>& inputs, array& out) {
}
}
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
if (not is_integral(in.dtype())) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
@@ -276,7 +332,7 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
@@ -298,7 +354,7 @@ void Log::eval(const std::vector<array>& inputs, array& out) {
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Log1p());
} else {
throw std::invalid_argument(
@@ -415,20 +471,24 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
if (in.flags().row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
} else {
shared_buffer_reshape(in, out_strides, out);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
}
}
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
if (not is_integral(in.dtype())) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
@@ -439,7 +499,7 @@ void Round::eval(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Sigmoid());
} else {
throw std::invalid_argument(
@@ -461,7 +521,7 @@ void Sign::eval(const std::vector<array>& inputs, array& out) {
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Sin());
} else {
throw std::invalid_argument(
@@ -473,7 +533,7 @@ void Sin::eval(const std::vector<array>& inputs, array& out) {
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Sinh());
} else {
throw std::invalid_argument(
@@ -488,66 +548,96 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
auto strides = in.strides();
auto flags = in.flags();
size_t data_offset = 0;
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
strides[i] *= strides_[i];
}
// Compute row/col contiguity
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
f_stride *= out.shape(i);
b_stride *= out.shape(ri);
if (strides[i] > 0) {
data_size *= out.shape(i);
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Square::eval(const std::vector<array>& inputs, array& out) {
@@ -566,10 +656,15 @@ void Sqrt::eval(const std::vector<array>& inputs, array& out) {
}
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Tan());
} else {
throw std::invalid_argument(
@@ -581,7 +676,7 @@ void Tan::eval(const std::vector<array>& inputs, array& out) {
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
if (is_floating_point(out.dtype())) {
unary_fp(in, out, detail::Tanh());
} else {
throw std::invalid_argument(
@@ -590,4 +685,38 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
}
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
}
void _qmm_dispatch(
array& out,
array out,
const array& x,
const array& w,
const array& scales,
@@ -253,81 +253,6 @@ void _qmm_dispatch(
}
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
@@ -357,45 +282,4 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
}
} // namespace mlx::core

View File

@@ -6,6 +6,8 @@
namespace mlx::core {
namespace {
enum ReductionOpType {
// Self-explanatory. Read everything and produce 1 output.
ContiguousAllReduce,
@@ -36,21 +38,6 @@ enum ReductionOpType {
GeneralReduce
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
namespace {
// Helper for the ndimensional strided loop
// Should this be in utils?
inline void nd_loop(
@@ -123,6 +110,19 @@ struct DefaultContiguousReduce {
}
};
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&

View File

@@ -0,0 +1,13 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@@ -222,7 +222,7 @@ void scan_dispatch(
}
case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto init = (issubdtype(input.dtype(), floating))
auto init = (is_floating_point(input.dtype()))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
@@ -232,7 +232,7 @@ void scan_dispatch(
}
case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (issubdtype(input.dtype(), floating))
auto init = (is_floating_point(input.dtype()))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {
template <typename T, typename AccT>
template <typename T>
void softmax(const array& in, array& out) {
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
@@ -22,36 +22,26 @@ void softmax(const array& in, array& out) {
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
// Find the maximum
current_in_ptr = in_ptr;
AccT maximum = *current_in_ptr;
T maximum = *current_in_ptr;
for (int j = 0; j < N; j++, current_in_ptr++) {
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
: maximum;
maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
}
// Compute the normalizer and the exponentials
AccT normalizer = 0;
T normalizer = 0;
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
AccT expv = std::exp(*current_in_ptr - maximum);
T expv = std::exp(*current_in_ptr - maximum);
normalizer += expv;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr = expv;
}
*current_out_ptr = expv;
}
normalizer = 1 / normalizer;
// Normalize
current_in_ptr = in_ptr;
current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
auto v = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(v * normalizer);
current_in_ptr++;
}
*current_out_ptr *= normalizer;
}
}
}
@@ -77,15 +67,11 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
}
};
array in = check_input(std::move(inputs[0]));
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
@@ -101,21 +87,13 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out);
softmax<float>(in, out);
break;
case float16:
if (precise_) {
softmax<float16_t, float>(in, out);
} else {
softmax<float16_t, float16_t>(in, out);
}
softmax<float16_t>(in, out);
break;
case bfloat16:
if (precise_) {
softmax<bfloat16_t, float>(in, out);
} else {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
softmax<bfloat16_t>(in, out);
break;
case complex64:
throw std::invalid_argument(

View File

@@ -1,147 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/primitives.h"
namespace mlx::core {
void svd_impl(const array& a, array& u, array& s, array& vt) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
// https://math.stackexchange.com/a/30077)
// A = UΣVᵀ
// Aᵀ = VΣUᵀ
// As a result some of the indices/sizes are swapped as noted above.
// Rows and cols of the original matrix in row-major order.
const int M = a.shape(-2);
const int N = a.shape(-1);
const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), float32, nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
// Allocate outputs.
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
static constexpr auto job_u = "V";
static constexpr auto job_vt = "V";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
float workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const float ignored_float = 0;
int info;
// Compute workspace size.
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<float>() + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s.data<float>() + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt.data<float>() + N * N * i,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u.data<float>() + M * M * i,
/* ldvt = */ &ldvt,
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
}
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
}
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -8,12 +8,11 @@
namespace mlx::core {
template <typename stride_t>
inline stride_t elem_to_loc(
inline size_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
stride_t loc = 0;
const std::vector<size_t>& strides) {
size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -29,93 +28,4 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline auto collapse_contiguous_dims(Arrays&&... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
} // namespace mlx::core

View File

@@ -5,16 +5,10 @@ add_custom_command(
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
kernels/bf16.h
kernels/erf.h
kernels/expm1f.h
kernels/utils.h
kernels/bf16_math.h
)
add_custom_target(
@@ -32,7 +26,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
@@ -40,7 +33,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

View File

@@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -140,15 +139,10 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
gc_limit_ = std::min(
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
}
buffer_cache_(device_),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::swap(limit, max_pool_size_);
@@ -170,15 +164,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
}
// Align up memory
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
@@ -223,11 +208,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_);
@@ -258,15 +238,9 @@ size_t get_active_memory() {
size_t get_peak_memory() {
return allocator().get_peak_memory();
}
void reset_peak_memory() {
allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
}
void clear_cache() {
return allocator().clear_cache();
}
} // namespace metal

View File

@@ -26,7 +26,6 @@ class BufferCache {
size_t cache_size() {
return pool_size_;
}
void clear();
private:
struct BufferHolder {
@@ -38,6 +37,7 @@ class BufferCache {
MTL::Buffer* buf;
};
void clear();
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
@@ -62,16 +62,11 @@ class MetalAllocator : public allocator::Allocator {
size_t get_peak_memory() {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_cache_memory() {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
void clear_cache();
private:
MTL::Device* device_;

View File

@@ -3,7 +3,6 @@
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
@@ -229,7 +228,14 @@ void Compiled::eval_gpu(
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, output_shape);
bool contiguous = true;
for (auto& x : inputs) {
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
!is_scalar(x)) {
contiguous = false;
break;
}
}
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
@@ -289,7 +295,7 @@ void Compiled::eval_gpu(
}
}
auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Put the inputs in
@@ -300,7 +306,7 @@ void Compiled::eval_gpu(
continue;
}
auto& x = inputs[i];
compute_encoder.set_input_array(x, cnt++);
set_array_buffer(compute_encoder, x, cnt++);
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
@@ -310,12 +316,30 @@ void Compiled::eval_gpu(
}
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);
// Allocate space for the outputs possibly with input donation
{
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].move_shared_buffer(in);
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
// Put the outputs in
for (auto& x : outputs) {
compute_encoder.set_output_array(x, cnt++);
set_array_buffer(compute_encoder, x, cnt++);
}
// Put the output shape and strides in
@@ -336,7 +360,7 @@ void Compiled::eval_gpu(
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +371,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -28,12 +28,10 @@ void explicit_gemm_conv_ND_gpu(
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
// Get gemm shapes
int implicit_M = out.size() / conv_params.O;
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K};
std::vector<int> unfolded_shape = {
static_cast<int>(out.size() / conv_params.O),
static_cast<int>(wt.size() / conv_params.O)};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -41,12 +39,12 @@ void explicit_gemm_conv_ND_gpu(
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
@@ -59,120 +57,27 @@ void explicit_gemm_conv_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags();
wt_flags.row_contiguous = false;
wt_flags.col_contiguous = true;
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_reshaped};
std::vector<array> copies;
return steel_matmul(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_reshaped,
/*b = */ wt,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*M = */ unfolded_shape[0],
/*N = */ conv_params.O,
/*K = */ unfolded_shape[1],
/*batch_size_out = */ 1,
/*a_cols = */ implicit_K,
/*b_cols = */ implicit_K,
/*a_cols = */ unfolded_shape[1],
/*b_cols = */ unfolded_shape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
}
template <int N>
void explicit_gemm_conv_group_ND_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Get gemm shapes
const int implicit_M = out.size() / conv_params.O;
const int implicit_K = wt.size() / conv_params.O;
const int implicit_N = O_per_group;
int kernel_size = 1;
for (int i = 0; i < N; ++i) {
kernel_size *= conv_params.wS[i];
}
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
@@ -183,7 +88,6 @@ void conv_1D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
@@ -203,15 +107,11 @@ void conv_1D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ groups,
/* const int groups = */ 1,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
void slow_conv_2D_gpu(
@@ -229,7 +129,7 @@ void slow_conv_2D_gpu(
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -242,12 +142,12 @@ void slow_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -330,7 +230,7 @@ void implicit_gemm_conv_2D_gpu(
<< "_filter_" << (small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -343,16 +243,16 @@ void implicit_gemm_conv_2D_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -483,7 +383,7 @@ void implicit_gemm_conv_2D_general_gpu(
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -497,9 +397,9 @@ void implicit_gemm_conv_2D_general_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -512,7 +412,7 @@ void implicit_gemm_conv_2D_general_gpu(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -600,12 +500,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
set_array_buffer(compute_encoder, wt, 0);
set_array_buffer(compute_encoder, filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -613,7 +513,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -628,12 +528,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
set_array_buffer(compute_encoder, in_padded, 0);
set_array_buffer(compute_encoder, inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -641,7 +541,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -676,12 +576,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
set_array_buffer(compute_encoder, out_wg, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -689,7 +589,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
@@ -759,56 +659,6 @@ void conv_2D_gpu(
}
}
void conv_3D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(4),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
/* const int idil[NDIM] = */
{in_dilation[0], in_dilation[1], in_dilation[2]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0],
in.strides()[1],
in.strides()[2],
in.strides()[3],
in.strides()[4]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0],
wt.strides()[1],
wt.strides()[2],
wt.strides()[3],
wt.strides()[4]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0],
out.strides()[1],
out.strides()[2],
out.strides()[3],
out.strides()[4]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -833,23 +683,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy;
}
// 3D conv
if (out.ndim() == 5) {
conv_3D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
}
// 2D conv
else if (out.ndim() == 4) {
if (out.ndim() == 4) {
conv_2D_gpu(
s,
d,
@@ -875,7 +710,6 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
}
// Throw error

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <sstream>
@@ -12,15 +12,8 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
if (in.dtype() == out.dtype()) {
return;
}
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -44,22 +37,15 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& strides_in_pre,
const std::vector<stride_t>& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s) {
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
auto [shape, strides] = collapse_contiguous_dims(in, out);
auto& strides_in = strides[0];
auto& strides_out = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
@@ -83,50 +69,45 @@ void copy_gpu_inplace(
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype());
out_offset *= size_of(out.dtype());
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
compute_encoder.set_output_array(out, 1, out_offset);
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
int ndim = shape.size();
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
size_t ndim = shape.size();
if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2);
}
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) {
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
}
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
compute_encoder->setBytes(
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int rest = data_size / (dim0 * dim1);
int rest = in.size() / (dim0 * dim1);
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,29 +116,8 @@ void copy_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -7,34 +7,12 @@
namespace mlx::core {
// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
const array& src,
array& out,
CopyType ctype,
const Stream& s);
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);
} // namespace mlx::core

View File

@@ -1,21 +1,17 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023-24 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -24,18 +20,9 @@ namespace mlx::core::metal {
namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if defined METAL_3_1
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
auto devices = MTL::CopyAllDevices();
@@ -46,6 +33,7 @@ auto load_device() {
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
@@ -124,33 +112,6 @@ MTL::Library* load_library(
} // namespace
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
@@ -165,6 +126,9 @@ Device::~Device() {
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
e.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -181,7 +145,6 @@ void Device::new_queue(int index) {
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
@@ -201,26 +164,27 @@ void Device::increment_command_buffer_ops(int index) {
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
return bit->second.second;
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
}
void Device::commit_command_buffer(int index) {
@@ -231,17 +195,24 @@ void Device::commit_command_buffer(int index) {
}
void Device::end_encoding(int index) {
encoder_map_.erase(index);
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
eit->second->endEncoding();
eit->second->release();
encoder_map_.erase(eit);
}
}
CommandEncoder& Device::get_command_encoder(int index) {
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
auto compute_encoder = cb->computeCommandEncoder();
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_.insert({index, compute_encoder}).first;
}
return *(eit->second);
return eit->second;
}
void Device::register_library(
@@ -283,17 +254,13 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto options = MTL::CompileOptions::alloc()->init();
options->setFastMathEnabled(false);
options->setLanguageVersion(get_metal_version());
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
options->release();
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n";
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -312,7 +279,8 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -370,6 +338,7 @@ MTL::Function* Device::get_function_(
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
@@ -538,13 +507,11 @@ MTL::ComputePipelineState* Device::get_kernel(
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
}
@@ -571,12 +538,11 @@ Device& device(mlx::core::Device) {
return metal_device;
}
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
std::shared_ptr<void> new_scoped_memory_pool() {
auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release();
};
return std::unique_ptr<void, std::function<void(void*)>>(
NS::AutoreleasePool::alloc()->init(), dtor);
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
@@ -585,23 +551,4 @@ void new_stream(Stream stream) {
}
}
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctl(mib, 2, &memsize, &length, NULL, 0);
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
}
} // namespace mlx::core::metal

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023-24 Apple Inc.
#pragma once
@@ -7,12 +7,10 @@
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
@@ -36,84 +34,6 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true;
}
~ConcurrentContext() {
enc.concurrent = false;
enc.outputs.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
enc.concurrent_outputs.clear();
}
private:
CommandEncoder& enc;
};
MTL::ComputeCommandEncoder* operator->() {
return enc;
}
void set_input_array(const array& a, int idx, int offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
private:
void maybe_split();
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
class Device {
public:
Device();
@@ -126,11 +46,12 @@ class Device {
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);
void commit_command_buffer(int index);
CommandEncoder& get_command_encoder(int index);
MTL::ComputeCommandEncoder* get_command_encoder(int index);
void end_encoding(int index);
void register_library(
@@ -211,7 +132,7 @@ class Device {
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
std::unordered_map<int32_t, MTL::ComputeCommandEncoder*> encoder_map_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;

View File

@@ -1,30 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core {
Event::Event(const Stream& stream) : stream_(stream) {
auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(ptr)->release();
};
auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>(
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
}
void Event::wait() {
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
->waitUntilSignaledValue(value(), -1)) {
throw std::runtime_error("[Event::wait] Timed out");
}
}
void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
}
} // namespace mlx::core

View File

@@ -1,106 +1,12 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
in.dtype() != complex64 || out.dtype() != complex64) {
// Could also fallback to CPU implementation here.
throw std::runtime_error(
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
}
size_t n = in.shape(axes_[0]);
if (!is_power_of_2(n) || n > 2048 || n < 4) {
throw std::runtime_error(
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [this, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
x.flags().col_contiguous;
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axes_[0]);
for (int axis = 0; axis < x.ndim(); axis++) {
if (axis == axes_[0]) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(axis);
}
}
auto flags = x.flags();
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
f_stride *= x.shape(i);
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
b_stride *= x.shape(ri);
}
// This is probably over-conservative
flags.contiguous = false;
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(inputs[0]);
// TODO: allow donation here
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
in_contiguous.data_size(),
in_contiguous.strides(),
in_contiguous.flags());
// We use n / 4 threads by default since radix-4
// is the largest single threaded radix butterfly
// we currently implement.
size_t m = n / 4;
size_t batch = in.size() / in.shape(axes_[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
{
std::ostringstream kname;
kname << "fft_" << n;
auto kernel = d.get_kernel(kname.str());
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
throw std::runtime_error("[FFT] NYI for Metal backend.");
}
} // namespace mlx::core

View File

@@ -16,7 +16,7 @@ namespace mlx::core {
namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
@@ -49,7 +49,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_" << idx_ndim;
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -81,8 +81,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Set all the buffers
compute_encoder.set_input_array(src, 0);
compute_encoder.set_output_array(out, 1);
set_array_buffer(compute_encoder, src, 0);
set_array_buffer(compute_encoder, out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
@@ -103,11 +103,11 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -183,7 +183,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
kname << "_" << nidx;
auto& compute_encoder = d.get_command_encoder(s.index);
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto& upd = inputs.back();
@@ -192,8 +192,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Set all the buffers
compute_encoder.set_input_array(upd, 1);
compute_encoder.set_output_array(out, 2);
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
// Set update info
uint upd_ndim = upd.ndim();
@@ -201,22 +201,25 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
@@ -280,13 +283,13 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -7,7 +7,6 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
@@ -21,12 +20,9 @@ set(
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
"random"
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
@@ -39,17 +35,11 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
-frecord-sources)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal
${METAL_FLAGS}
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
COMMAND xcrun -sdk macosx metal -Wall -Wextra
-fno-fast-math
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS}
OUTPUT ${TARGET}.air

View File

@@ -11,14 +11,14 @@ template <typename T>
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] \
[[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
@@ -29,4 +29,4 @@ instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
instantiate_arange(bfloat16, bfloat16_t)

Some files were not shown because too many files have changed in this diff Show More