Compare commits

..

2 Commits

Author SHA1 Message Date
Angelos Katharopoulos
127de8821e Fix the sig_handler check 2025-03-07 17:31:06 -08:00
Awni Hannun
3ad9031a7f fences must exit 2025-03-07 09:28:33 -08:00
279 changed files with 5840 additions and 12910 deletions

View File

@@ -24,8 +24,8 @@ jobs:
type: boolean
default: false
macos:
xcode: "16.2.0"
resource_class: m2pro.medium
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
@@ -89,7 +89,6 @@ jobs:
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
@@ -109,8 +108,6 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
@@ -125,15 +122,10 @@ jobs:
parameters:
xcode_version:
type: string
default: "16.2.0"
macosx_deployment_target:
type: string
default: ""
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m2pro.medium
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
@@ -154,9 +146,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
- run:
name: Generate package stubs
command: |
@@ -219,18 +209,13 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "16.2.0"
default: "15.2.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
@@ -251,7 +236,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
pip install . -v
- run:
@@ -346,7 +331,7 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test
- build_documentation
@@ -366,70 +351,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -452,7 +375,7 @@ workflows:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "14.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -465,54 +388,7 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
and:
@@ -523,70 +399,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:

1
.gitignore vendored
View File

@@ -36,7 +36,6 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock
# vim
*.swp

View File

@@ -9,7 +9,6 @@ if(NOT MLX_VERSION)
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
set(_patch ${CMAKE_MATCH_1})
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
${MLX_VERSION})
@@ -42,6 +41,8 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
# --------------------- Processor tests -------------------------
message(
STATUS
@@ -76,6 +77,7 @@ include(FetchContent)
cmake_policy(SET CMP0135 NEW)
add_library(mlx)
set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if(MLX_BUILD_METAL)
set(METAL_LIB "-framework Metal")
@@ -212,6 +214,24 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if(MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET)
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif(MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING "MPI found but mpirun is not available. Building without MPI.")
else()
set(MPI_FOUND FALSE)
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
endif()
endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json

View File

@@ -5,26 +5,26 @@ possible.
## Pull Requests
1. Fork and submit pull requests to the repo.
1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and python code.
You can also run the formatters manually as follows:
```shell
clang-format -i file.cpp
```
```shell
black file.py
```
```
clang-format -i file.cpp
```
```
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.
## Issues

View File

@@ -1,6 +1,4 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from time import time
import mlx.core as mx
import torch

View File

@@ -1,74 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()

View File

@@ -1,84 +0,0 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()

View File

@@ -28,34 +28,11 @@ def bench(f, *args):
return (e - s) * 1e-9
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_sdpa_fused_inner(q, k, v, scale):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
@@ -64,7 +41,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
B = q.shape[0]
L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -72,27 +48,10 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if mask is not None:
if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
@@ -101,55 +60,74 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
return out
def mlx_fused_attn(q, k, v, scale, mask):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose:
q_t = mx.transpose(q, (0, 2, 1, 3))
k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
time_mlx_unfused = bench(
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
scale = math.sqrt(1.0 / head_dim)
atol = 1e-5 if dtype == "float32" else 2e-4
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
@@ -173,51 +151,39 @@ if __name__ == "__main__":
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 8),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 8),
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 8),
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape(
B,
qsl,
ksl,
head_dim,
n_q_heads,
n_kv_heads,
dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = NO
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

View File

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete:
.. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const array& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
@@ -247,7 +247,9 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Allocate the output with `malloc_or_wait` which synchronously allocates
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -391,7 +393,7 @@ below.
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel
std::ostringstream kname;
@@ -469,7 +471,7 @@ one we just defined:
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops
// The jvp transform on the primitive can built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
@@ -481,7 +483,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
}
// If argnums = {0, 1}, we take contributions from both
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -735,7 +737,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c is correct: {mx.all(c == 6.0).item()}")
print(f"c correct: {mx.all(c == 6.0).item()}")
Output:
@@ -743,7 +745,7 @@ Output:
c shape: [3, 4]
c dtype: float32
c is correct: True
c correctness: True
Results
^^^^^^^

View File

@@ -70,7 +70,6 @@ are the CPU and GPU.
python/fft
python/linalg
python/metal
python/memory_management
python/nn
python/optimizers
python/distributed

View File

@@ -38,7 +38,6 @@ Array
array.log10
array.log1p
array.log2
array.logcumsumexp
array.logsumexp
array.max
array.mean

View File

@@ -20,5 +20,3 @@ FFT
irfft2
rfftn
irfftn
fftshift
ifftshift

View File

@@ -20,6 +20,5 @@ Linear Algebra
eigh
lu
lu_factor
pinv
solve
solve_triangular

View File

@@ -1,16 +0,0 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache

View File

@@ -8,5 +8,13 @@ Metal
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -36,12 +36,10 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
broadcast_arrays
broadcast_to
ceil
clip
concatenate
contiguous
conj
conjugate
convolve
@@ -103,7 +101,6 @@ Operations
log10
log1p
logaddexp
logcumsumexp
logical_not
logical_and
logical_or

View File

@@ -18,4 +18,3 @@ Common Optimizers
AdamW
Adamax
Lion
MultiOptimizer

View File

@@ -9,7 +9,6 @@ Transforms
:toctree: _autosummary
eval
async_eval
compile
custom_function
disable_compile

View File

@@ -72,7 +72,9 @@ void axpby_impl(
float alpha_,
float beta_,
mx::Stream stream) {
out.set_data(mx::allocator::malloc(out.nbytes()));
// Allocate the output with `malloc_or_wait` which synchronously allocates
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -158,12 +160,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization
if (contiguous_kernel) {
out.set_data(
mx::allocator::malloc(x.data_size() * out.itemsize()),
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
x.data_size(),
x.strides(),
x.flags());
} else {
out.set_data(mx::allocator::malloc(out.nbytes()));
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
}
// Resolve name of kernel (corresponds to axpby.metal)

View File

@@ -5,7 +5,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
@@ -18,13 +17,9 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
if(MSVC)
# Disable some MSVC warnings to speed up compilation.
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
@@ -47,10 +42,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()

View File

@@ -4,11 +4,12 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
auto buffer = allocator().malloc(size, /* allow_swap */ true);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -21,4 +22,45 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator

View File

@@ -32,10 +32,14 @@ Buffer malloc(size_t size);
void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
@@ -49,4 +53,16 @@ class Allocator {
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator

View File

@@ -56,18 +56,6 @@ std::vector<array> array::make_arrays(
return outputs;
}
array array::unsafe_weak_copy(const array& other) {
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
cpy.set_data(
other.buffer(),
other.data_size(),
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
return cpy;
}
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
Shape{static_cast<ShapeElem>(data.size())},

View File

@@ -199,13 +199,6 @@ class array {
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs);
/**
* Get a new array that refers to the same data as the input but with a
* non-owning pointer to it. Note the array is detached from the graph and has
* no inputs, siblings or primitive.
*/
static array unsafe_weak_copy(const array& other);
/** A unique identifier for an array. */
std::uintptr_t id() const {
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -339,11 +332,11 @@ class array {
return allocator::allocator().size(buffer());
}
// Return the shared pointer to the array::Data struct
const std::shared_ptr<Data>& data_shared_ptr() const {
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
@@ -356,7 +349,7 @@ class array {
}
enum Status {
// The output of a computation which has not been scheduled.
// The ouptut of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`.
unscheduled,

View File

@@ -1,7 +1,6 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

View File

@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(b.data_size() * out.itemsize()),
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}

View File

@@ -1,24 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -1,11 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core

View File

@@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -43,6 +42,23 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}
@@ -87,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {

View File

@@ -188,7 +188,7 @@ void compiled_allocate_outputs(
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()),
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
@@ -211,7 +211,7 @@ void compiled_allocate_outputs(
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
}

View File

@@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return false;
}
}

View File

@@ -99,11 +99,7 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m};
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),
itemsize = out.itemsize(),

View File

@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc(out.itemsize() * b.data_size()),
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break;
}

View File

@@ -40,8 +40,7 @@ add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
@@ -59,7 +58,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -75,8 +73,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
endif()
if(IOS)

View File

@@ -11,7 +11,12 @@ namespace mlx::core {
namespace {
template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
void arg_reduce(
const array& in,
array& out,
const OpT& op,
int axis,
Stream stream) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
@@ -21,16 +26,28 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in_ptr,
out_ptr,
axis_size,
axis_stride,
op = std::move(op),
shape = std::move(shape),
strides = std::move(strides),
size = out.size()]() {
for (uint32_t i = 0; i < size; ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out_ptr[i] = ind_v;
}
out_ptr[i] = ind_v;
}
});
}
template <typename InT>
@@ -38,7 +55,8 @@ void arg_reduce_dispatch(
const array& in,
array& out,
ArgReduce::ReduceType rtype,
int axis) {
int axis,
Stream stream) {
switch (rtype) {
case ArgReduce::ArgMin: {
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
@@ -47,7 +65,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
arg_reduce<InT>(in, out, op, axis, stream);
break;
}
case ArgReduce::ArgMax: {
@@ -57,7 +75,7 @@ void arg_reduce_dispatch(
(*ind_y) = ind_x;
}
};
arg_reduce<InT>(in, out, op, axis);
arg_reduce<InT>(in, out, op, axis, stream);
break;
}
}
@@ -68,59 +86,52 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;
}
});
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (in.dtype()) {
case bool_:
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
break;
case uint8:
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
break;
case uint16:
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
break;
case uint32:
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
break;
case uint64:
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
break;
case int8:
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
break;
case int16:
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
break;
case int32:
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
break;
case int64:
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
break;
case float16:
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
break;
case float32:
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
break;
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
break;
}
}
} // namespace mlx::core

View File

@@ -1,11 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

View File

@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

View File

@@ -8,7 +8,6 @@
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -17,221 +16,51 @@ namespace mlx::core {
namespace {
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
void comparison_op(const array& a, const array& b, array& out) {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out);
break;
case float32:
binary_op<float, bool, Op>(a, b, out);
break;
case float64:
binary_op<double, bool, Op>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out);
break;
}
}
} // namespace
@@ -240,7 +69,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add(), stream());
binary(a, b, out, detail::Add());
}
void DivMod::eval_cpu(
@@ -249,89 +78,70 @@ void DivMod::eval_cpu(
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out_a = array::unsafe_weak_copy(out_a),
out_b = array::unsafe_weak_copy(out_b),
bopt]() mutable {
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (out_a.dtype()) {
case bool_:
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
case uint8:
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint16:
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint32:
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case uint64:
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int8:
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int16:
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int32:
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case int64:
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
break;
case float16:
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case float32:
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
break;
case float64:
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
});
auto integral_op = [](auto x, auto y) {
return std::make_pair(x / y, x % y);
};
auto float_op = [](auto x, auto y) {
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
};
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, integral_op);
case uint8:
binary_op<uint8_t>(a, b, outputs, integral_op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, integral_op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, integral_op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, integral_op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, integral_op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, integral_op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, integral_op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, integral_op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, float_op);
break;
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
case complex64:
// Should never get here
throw std::runtime_error("[DivMod] Complex type not supported");
break;
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide(), stream());
binary(a, b, out, detail::Divide());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder(), stream());
binary(a, b, out, detail::Remainder());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -339,143 +149,181 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& a = inputs[0];
auto& b = inputs[1];
if (equal_nan_) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
});
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
break;
case float32:
binary_op<float, bool, detail::NaNEqual>(a, b, out);
break;
case float64:
binary_op<double, bool, detail::NaNEqual>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
break;
case complex64:
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
} else {
comparison_op(a, b, out, detail::Equal(), stream());
comparison_op<detail::Equal>(a, b, out);
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
comparison_op<detail::Greater>(inputs[0], inputs[1], out);
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
comparison_op<detail::Less>(inputs[0], inputs[1], out);
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_float(a, b, out, detail::LogAddExp(), stream());
switch (out.dtype()) {
case float16:
binary_op<float16_t, detail::LogAddExp>(a, b, out);
break;
case float32:
binary_op<float, detail::LogAddExp>(a, b, out);
break;
case float64:
binary_op<double, detail::LogAddExp>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd(), stream());
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr(), stream());
binary(in1, in2, out, detail::LogicalOr());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum(), stream());
binary(a, b, out, detail::Maximum());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum(), stream());
binary(a, b, out, detail::Minimum());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply(), stream());
binary(a, b, out, detail::Multiply());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power(), stream());
binary(a, b, out, detail::Power());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract(), stream());
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
binary_int(a, b, out, detail::BitwiseAnd(), stream());
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
binary_int(a, b, out, detail::BitwiseOr(), stream());
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
binary_int(a, b, out, detail::BitwiseXor(), stream());
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
binary_int(a, b, out, detail::LeftShift(), stream());
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
binary_int(a, b, out, detail::RightShift(), stream());
dispatch_type(detail::RightShift());
break;
}
}
@@ -484,7 +332,23 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
binary_float(a, b, out, detail::ArcTan2(), stream());
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
}
} // namespace mlx::core

View File

@@ -3,9 +3,12 @@
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -149,145 +152,218 @@ void binary_op_dispatch_dims(
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
void binary_op(const array& a, const array& b, array& out) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_ptr = out.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
*out_ptr = Op{}(*a_ptr, *b_ptr);
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
auto& a_strides = new_strides[0];
auto& b_strides = new_strides[1];
auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([bopt,
a_ptr,
b_ptr,
out_ptr,
a_data_size = a.data_size(),
b_data_size = b.data_size(),
size = a.size(),
shape = a.shape(),
a_strides = a.strides(),
b_strides = b.strides(),
strides = out.strides()]() mutable {
if (bopt == BinaryOpType::ScalarScalar) {
*out_ptr = Op{}(*a_ptr, *b_ptr);
return;
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
return;
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
return;
}
// Case 1: LxM and FxM where L and F are broadcastable and M is row
// contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
shape,
{std::move(a_strides), std::move(b_strides), std::move(strides)});
a_strides = new_strides[0];
b_strides = new_strides[1];
strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false, Op>(
a_ptr,
b_ptr,
out_ptr,
dim,
a.size(),
new_shape,
a_strides,
b_strides,
strides);
break;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false, Op>(
a_ptr,
b_ptr,
out_ptr,
dim,
size,
new_shape,
a_strides,
b_strides,
strides);
break;
}
});
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
void binary_op(const array& a, const array& b, array& out) {
binary_op<T, T, Op>(a, b, out);
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T, Op>(a, b, out);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out);
break;
case int8:
binary_op<int8_t, Op>(a, b, out);
break;
case int16:
binary_op<int16_t, Op>(a, b, out);
break;
case int32:
binary_op<int32_t, Op>(a, b, out);
break;
case int64:
binary_op<int64_t, Op>(a, b, out);
break;
case float16:
binary_op<float16_t, Op>(a, b, out);
break;
case float32:
binary_op<float, Op>(a, b, out);
break;
case float64:
binary_op<double, Op>(a, b, out);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out);
break;
}
}
} // namespace mlx::core

View File

@@ -4,6 +4,8 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -55,7 +57,14 @@ void binary_op_dispatch_dims(
const array& b,
array& out_a,
array& out_b,
Stream stream,
Op op) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out_a.strides()});
const T* a_ptr = a.data<T>();
@@ -63,101 +72,197 @@ void binary_op_dispatch_dims(
U* out_a_ptr = out_a.data<U>();
U* out_b_ptr = out_b.data<U>();
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
int ndim = shape.size();
switch (ndim) {
case 1:
binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
shape = std::move(shape),
strides = std::move(strides),
op = std::move(op)]() {
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& out_strides = strides[2];
int ndim = shape.size();
switch (ndim) {
case 1:
binary_op_dims<T, U, Op, 1>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2>(
a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < size; elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_a_ptr + elem,
out_b_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
}
});
}
template <typename T, typename U = T, typename Op>
void binary_op(
const array& a,
const array& b,
array& out_a,
array& out_b,
Op op,
BinaryOpType bopt) {
std::vector<array>& outputs,
Op op) {
auto bopt = get_binary_op_type(a, b);
auto& out_a = outputs[0];
auto& out_b = outputs[1];
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto stream = out_a.primitive().stream();
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::General) {
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op);
return;
}
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
auto a_ptr = a.data<T>();
auto b_ptr = b.data<T>();
auto out_a_ptr = out_a.data<U>();
auto out_b_ptr = out_b.data<U>();
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
encoder.dispatch(
[a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
});
} else if (bopt == BinaryOpType::ScalarVector) {
for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = b.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
});
} else if (bopt == BinaryOpType::VectorScalar) {
for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
});
} else { // VectorVector
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
out_a_ptr,
out_b_ptr,
size = a.size(),
op = std::move(op)]() mutable {
for (size_t i = 0; i < size; ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
b_ptr++;
}
});
}
}
template <typename Op>
void binary(
const array& a,
const array& b,
std::vector<array>& outputs,
Op op) {
switch (outputs[0].dtype()) {
case bool_:
binary_op<bool>(a, b, outputs, op);
break;
case uint8:
binary_op<uint8_t>(a, b, outputs, op);
break;
case uint16:
binary_op<uint16_t>(a, b, outputs, op);
break;
case uint32:
binary_op<uint32_t>(a, b, outputs, op);
break;
case uint64:
binary_op<uint64_t>(a, b, outputs, op);
break;
case int8:
binary_op<int8_t>(a, b, outputs, op);
break;
case int16:
binary_op<int16_t>(a, b, outputs, op);
break;
case int32:
binary_op<int32_t>(a, b, outputs, op);
break;
case int64:
binary_op<int64_t>(a, b, outputs, op);
break;
case float16:
binary_op<float16_t>(a, b, outputs, op);
break;
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;
case complex64:
binary_op<complex64_t>(a, b, outputs, op);
break;
}
}

View File

@@ -40,10 +40,7 @@ struct CompilerCache {
std::shared_mutex mtx;
};
static CompilerCache& cache() {
static CompilerCache cache_;
return cache_;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
@@ -59,16 +56,14 @@ void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache().mtx);
if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
@@ -125,10 +120,10 @@ void* compile(
}
// load library
cache().libs.emplace_back(shared_lib_path);
cache.libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -136,7 +131,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
cache().kernels.insert({kernel_name, fun});
cache.kernels.insert({kernel_name, fun});
return fun;
}

View File

@@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
@@ -1327,7 +1327,7 @@ void conv_3D_cpu(
} // namespace
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& in = inputs[0];
auto& wt = inputs[1];

View File

@@ -13,20 +13,29 @@ namespace mlx::core {
namespace {
template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
void copy_single(const array& src, array& dst, Stream stream) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
auto size = dst.size();
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
auto val = static_cast<DstT>(src_ptr[0]);
std::fill_n(dst_ptr, size, val);
});
}
template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
void copy_vector(const array& src, array& dst, Stream stream) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
auto size = src.data_size();
std::copy(src_ptr, src_ptr + size, dst_ptr);
size_t size = src.data_size();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
std::copy(src_ptr, src_ptr + size, dst_ptr);
});
}
template <typename SrcT, typename DstT, int D>
@@ -57,6 +66,7 @@ template <typename SrcT, typename DstT>
void copy_general_general(
const array& src,
array& dst,
Stream stream,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
@@ -70,17 +80,47 @@ void copy_general_general(
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
auto o_offset_ptr =
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
auto size = src.size();
if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
int ndim = shape.size();
if (ndim < 3) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch([src_ptr,
dst_ptr,
size = src.size(),
data_shape = data_shape,
i_strides = i_strides,
o_strides = o_strides,
i_offset_ptr,
o_offset_ptr]() mutable {
if (data_shape.empty()) {
auto val = static_cast<DstT>(*src_ptr);
*dst_ptr = val;
return;
}
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
int ndim = shape.size();
if (ndim < 3) {
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
}
return;
}
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
@@ -88,47 +128,30 @@ void copy_general_general(
dst_ptr += o_offset_ptr[0];
}
if (ndim == 1) {
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 2) {
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
} else if (ndim == 3) {
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
return;
}
if (i_offset_ptr) {
src_ptr += i_offset_ptr[0];
}
if (o_offset_ptr) {
dst_ptr += o_offset_ptr[0];
}
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < size; elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
strides[0],
strides[1],
ndim - 3);
in.step();
out.step();
}
});
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
inline void copy_general_general(const array& src, array& dst, Stream stream) {
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
src.shape(),
src.strides(),
dst.strides(),
@@ -142,6 +165,7 @@ template <typename SrcT, typename DstT>
void copy_general(
const array& src,
array& dst,
Stream stream,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
@@ -152,6 +176,7 @@ void copy_general(
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
data_shape,
i_strides,
make_contiguous_strides(data_shape),
@@ -162,10 +187,11 @@ void copy_general(
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
inline void copy_general(const array& src, array& dst, Stream stream) {
copy_general_general<SrcT, DstT>(
src,
dst,
stream,
src.shape(),
src.strides(),
make_contiguous_strides(src.shape()),
@@ -176,67 +202,84 @@ inline void copy_general(const array& src, array& dst) {
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
void copy(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
copy_single<SrcT, DstT>(src, dst, stream);
return;
case CopyType::Vector:
copy_vector<SrcT, DstT>(src, dst);
copy_vector<SrcT, DstT>(src, dst, stream);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
copy_general_general<SrcT, DstT>(
src, dst, stream, std::forward<Args>(args)...);
return;
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
void copy(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint32_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, uint64_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float64:
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, bfloat16_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<SrcT, complex64_t>(
src, dst, ctype, stream, std::forward<Args>(args)...);
break;
}
}
@@ -246,49 +289,50 @@ inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Stream stream,
Args&&... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int8:
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int16:
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int32:
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case int64:
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float16:
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case float64:
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
break;
}
}
@@ -296,13 +340,7 @@ inline void copy_inplace_dispatch(
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
copy_inplace_dispatch(src, dst, ctype, stream);
}
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
@@ -330,47 +368,26 @@ void copy_inplace(
Stream stream,
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(src);
encoder.set_output_array(dst);
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
if (x) {
return array::unsafe_weak_copy(*x);
} else {
return std::nullopt;
}
};
encoder.dispatch(
[src = array::unsafe_weak_copy(src),
dst = array::unsafe_weak_copy(dst),
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
ctype,
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
dynamic_i_offset,
dynamic_o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype);
}
});
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
copy_inplace_dispatch(
src,
dst,
ctype,
stream,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset,
dynamic_i_offset,
dynamic_o_offset);
break;
case CopyType::Scalar:
case CopyType::Vector:
copy_inplace_dispatch(src, dst, ctype, stream);
}
}
} // namespace mlx::core

View File

@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
return in;
} else {
@@ -46,15 +46,8 @@ void AllReduce::eval_cpu(
case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream());
break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default:
throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
@@ -65,7 +58,7 @@ void AllGather::eval_cpu(
assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) {
auto& enc = cpu::get_command_encoder(stream());
@@ -94,7 +87,7 @@ void Recv::eval_cpu(
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}

View File

@@ -55,8 +55,9 @@ void eigh_impl(
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
auto iwork_buf =
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>(
&jobz,
@@ -97,7 +98,7 @@ void Eigh::eval_cpu(
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc(values.nbytes()));
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,

View File

@@ -9,9 +9,6 @@
namespace mlx::core::cpu {
// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;
struct CommandEncoder {
CommandEncoder(Stream stream) : stream_(stream) {}
@@ -42,24 +39,13 @@ struct CommandEncoder {
template <class F, class... Args>
void dispatch(F&& f, Args&&... args) {
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
if (num_ops_ == 0) {
scheduler::notify_new_task(stream_);
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
task();
scheduler::notify_task_completion(s);
};
scheduler::enqueue(stream_, std::move(task_wrap));
} else {
scheduler::enqueue(stream_, std::move(task));
}
scheduler::enqueue(stream_, std::move(task));
}
private:
Stream stream_;
std::vector<array> temporaries_;
int num_ops_{0};
};
CommandEncoder& get_command_encoder(Stream stream);

View File

@@ -33,8 +33,12 @@ void eval(array& arr) {
buffers.erase(it);
}
auto& encoder = cpu::get_command_encoder(s);
encoder.dispatch([buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {});
scheduler::notify_new_task(s);
encoder.dispatch([s,
buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {
scheduler::notify_task_completion(s);
});
}
} // namespace mlx::core::cpu

View File

@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize();
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<size_t> shape;
if (out.dtype() == float32) {

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -0,0 +1,27 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,45 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -1,45 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core

View File

@@ -1,139 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@ namespace mlx::core {
template <typename T>
void general_inv(T* inv, int N) {
int info;
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
getrf<T>(
/* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
// Compute inverse.
getri<T>(

View File

@@ -1,140 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core

View File

@@ -30,7 +30,8 @@ void luf_impl(
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a,
lu,
@@ -43,8 +44,8 @@ void luf_impl(
stream);
auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
size_t num_matrices = a.size() / (M * N);

View File

@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];

View File

@@ -115,7 +115,7 @@ void matmul_general(
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);

View File

@@ -21,7 +21,7 @@ namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else {
shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc(offset.itemsize()));
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
}
auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
switch (out.dtype()) {
case bool_:
throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
@@ -205,10 +205,8 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General, stream());
@@ -278,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
@@ -337,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace(
@@ -452,7 +450,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);

View File

@@ -25,11 +25,12 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc(r.nbytes()));
q.set_data(allocator::malloc_or_wait(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes()));
auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>();
@@ -40,7 +41,8 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N);
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
T optimal_work;
int lwork = -1;
@@ -51,7 +53,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc(sizeof(T) * lwork);
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
@@ -94,7 +96,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc(sizeof(T) * lwork);
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {

View File

@@ -326,7 +326,8 @@ void _qmm_dispatch_typed(
const array& biases,
int bits,
int group_size,
bool transposed_w) {
bool transposed_w,
Stream stream) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
@@ -334,25 +335,56 @@ void _qmm_dispatch_typed(
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
w_els,
g_els,
batch_size,
M,
N,
K,
bits,
group_size,
transposed_w] {
for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
M,
N,
K,
bits,
group_size,
transposed_w);
}
});
}
void _qmm_dispatch(
@@ -363,19 +395,20 @@ void _qmm_dispatch(
const array& biases,
int bits,
int group_size,
bool transposed_w) {
bool transposed_w,
Stream stream) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w);
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w);
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w);
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
break;
default:
throw std::invalid_argument(
@@ -394,7 +427,8 @@ void _bs_qmm_dispatch_typed(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
bool transposed_w,
Stream stream) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
@@ -402,6 +436,15 @@ void _bs_qmm_dispatch_typed(
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
@@ -410,26 +453,53 @@ void _bs_qmm_dispatch_typed(
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
biases_ptr +
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
lhs_indices_ptr,
rhs_indices_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
w_els,
g_els,
indices_size = lhs_indices.size(),
M,
N,
K,
bits,
group_size,
transposed_w]() {
for (int i = 0; i < indices_size; i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
M,
N,
K,
bits,
group_size,
transposed_w);
}
});
}
void _bs_qmm_dispatch(
@@ -442,7 +512,8 @@ void _bs_qmm_dispatch(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
bool transposed_w,
Stream stream) {
switch (x.dtype()) {
case float32:
_bs_qmm_dispatch_typed<float>(
@@ -455,7 +526,8 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w);
transposed_w,
stream);
break;
case float16:
_bs_qmm_dispatch_typed<float16_t>(
@@ -468,7 +540,8 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w);
transposed_w,
stream);
break;
case bfloat16:
_bs_qmm_dispatch_typed<bfloat16_t>(
@@ -481,7 +554,8 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w);
transposed_w,
stream);
break;
default:
throw std::invalid_argument(
@@ -515,25 +589,11 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -565,39 +625,21 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_,
stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
}
template <typename T, typename U>
@@ -667,13 +709,27 @@ void dispatch_quantize(
array& scales,
array& biases,
int bits,
int group_size) {
int group_size,
Stream stream) {
auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w_ptr,
out_ptr,
scales_ptr,
biases_ptr,
bits,
group_size,
w_size = w.size()]() {
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
});
}
void fast::AffineQuantize::eval_cpu(
@@ -691,55 +747,43 @@ void fast::AffineQuantize::eval_cpu(
auto [w, copied] = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
encoder.add_temporary(w);
}
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w = array::unsafe_weak_copy(w),
out = array::unsafe_weak_copy(out),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_]() mutable {
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
});
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(w);
}
}
} // namespace mlx::core

View File

@@ -140,23 +140,34 @@ void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init) {
U init,
Stream stream) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_output_array(out);
auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
});
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
encoder.dispatch(
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
});
return;
}
@@ -167,29 +178,40 @@ void reduction_op(
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
}
} else {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
Op{},
init);
},
plan.shape,
plan.strides);
}
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
Op{},
init);
},
plan.shape,
plan.strides);
}
}
});
return;
}
@@ -198,12 +220,20 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size()]() mutable {
for (int i = 0; i < size; i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
});
return;
}
@@ -215,49 +245,67 @@ void reduction_op(
plan.strides.pop_back();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride;
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < size; i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
Op{});
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
Op{});
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
});
return;
}
if (plan.type == GeneralReduce) {
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = Op{}(val, *(in_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
encoder.dispatch([in_ptr,
out_ptr,
init,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = Op{}(val, *(in_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
});
}
}
@@ -386,11 +434,12 @@ void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
const std::vector<int>& axes,
Stream stream) {
if (rtype == Reduce::And) {
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
} else {
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
}
}
@@ -399,18 +448,19 @@ void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
const std::vector<int>& axes,
Stream stream) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
} else {
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
} else {
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
}
}
}
@@ -420,144 +470,162 @@ void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
const std::vector<int>& axes,
Stream stream) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
}
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axes_ = axes_]() mutable {
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
}
break;
}
});
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(
in, out, reduce_type_, axes_, stream());
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(
in, out, reduce_type_, axes_, stream());
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
break;
case float32:
reduce_dispatch_sum_prod<float>(
in, out, reduce_type_, axes_, stream());
break;
case float64:
reduce_dispatch_sum_prod<double>(
in, out, reduce_type_, axes_, stream());
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(
in, out, reduce_type_, axes_, stream());
break;
}
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(
in, out, reduce_type_, axes_, stream());
break;
case int8:
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
break;
case int32:
reduce_dispatch_min_max<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case int64:
reduce_dispatch_min_max<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
case float16:
reduce_dispatch_min_max<float16_t>(
in, out, reduce_type_, axes_, stream());
break;
case float32:
reduce_dispatch_min_max<float>(
in, out, reduce_type_, axes_, stream());
break;
case float64:
reduce_dispatch_min_max<double>(
in, out, reduce_type_, axes_, stream());
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(
in, out, reduce_type_, axes_, stream());
break;
}
break;
}
}
}
} // namespace mlx::core

View File

@@ -3,7 +3,6 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
@@ -161,29 +160,38 @@ void scan_op(
bool reverse,
bool inclusive,
const Op& op,
U init) {
U init,
Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (in.flags().row_contiguous) {
if (in.strides()[axis] == 1) {
contiguous_scan(
in.data<T>(),
out.data<U>(),
in.size() / in.shape(axis),
in.shape(axis),
reverse,
inclusive,
op,
init);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis),
stride = in.shape(axis),
reverse,
inclusive,
op = std::move(op),
init]() {
contiguous_scan(
in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
});
} else {
strided_scan(
in.data<T>(),
out.data<U>(),
in.size() / in.shape(axis) / in.strides()[axis],
in.shape(axis),
in.strides()[axis],
reverse,
inclusive,
op,
init);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<U>(),
count = in.size() / in.shape(axis) / in.strides()[axis],
size = in.shape(axis),
stride = in.strides()[axis],
reverse,
inclusive,
op = std::move(op),
init]() {
strided_scan(
in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
});
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
@@ -197,18 +205,19 @@ void scan_dispatch(
array& out,
int axis,
bool reverse,
bool inclusive) {
bool inclusive,
Stream stream) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
case Scan::Prod: {
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
case Scan::Min: {
@@ -216,7 +225,7 @@ void scan_dispatch(
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
case Scan::Max: {
@@ -224,17 +233,7 @@ void scan_dispatch(
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
break;
}
case Scan::LogAddExp: {
auto op = [](U a, T b) {
return detail::LogAddExp{}(a, static_cast<U>(b));
};
auto init = (issubdtype(in.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
break;
}
}
@@ -245,96 +244,88 @@ void scan_dispatch(
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& encoder = cpu::get_command_encoder(stream());
// Ensure contiguity
auto in = inputs[0];
bool copied = false;
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General, stream());
in = arr_copy;
encoder.add_temporary(arr_copy);
copied = true;
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
reduce_type_ = reduce_type_,
reverse_ = reverse_,
inclusive_ = inclusive_]() mutable {
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
scan_dispatch<complex64_t, complex64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
break;
}
});
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(std::move(in));
}
}
} // namespace mlx::core

View File

@@ -16,70 +16,51 @@ void select_op(
const array& b,
const array& c,
array& out,
Op op,
Stream stream) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
c = array::unsafe_weak_copy(c),
out = array::unsafe_weak_copy(out),
op,
topt]() mutable {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(
a, b, c, out, op, topt);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
a, b, c, out, op, topt);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
a, b, c, out, op, topt);
break;
}
});
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
}
} // namespace
@@ -89,7 +70,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select(), stream());
select_op(condition, a, b, out, detail::Select());
}
} // namespace mlx::core

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
#endif
template <>
inline constexpr int max_size<float16_t> = N;
static constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \

View File

@@ -83,25 +83,25 @@ struct Simd {
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
inline constexpr int max_size<int8_t> = 16;
static constexpr int max_size<int8_t> = 16;
template <>
inline constexpr int max_size<int16_t> = 16;
static constexpr int max_size<int16_t> = 16;
template <>
inline constexpr int max_size<int> = 8;
static constexpr int max_size<int> = 8;
template <>
inline constexpr int max_size<int64_t> = 4;
static constexpr int max_size<int64_t> = 4;
template <>
inline constexpr int max_size<uint8_t> = 16;
static constexpr int max_size<uint8_t> = 16;
template <>
inline constexpr int max_size<uint16_t> = 16;
static constexpr int max_size<uint16_t> = 16;
template <>
inline constexpr int max_size<uint32_t> = 8;
static constexpr int max_size<uint32_t> = 8;
template <>
inline constexpr int max_size<uint64_t> = 4;
static constexpr int max_size<uint64_t> = 4;
template <>
inline constexpr int max_size<float> = 8;
static constexpr int max_size<float> = 8;
template <>
inline constexpr int max_size<double> = 4;
static constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \

View File

@@ -87,45 +87,14 @@ DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> log1p(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto x = in.value.real();
auto y = in.value.imag();
auto zabs = std::abs(in.value);
auto theta = std::atan2(y, x + 1);
if (zabs < 0.5) {
auto r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return Simd<T, 1>{T{x, theta}};
}
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
} else {
auto z0 = std::hypot(x + 1, y);
return Simd<T, 1>{T{std::log(z0), theta}};
}
} else {
return Simd<T, 1>{std::log1p(in.value)};
}
}
template <typename T>
Simd<T, 1> log2(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
auto out = std::log(in.value);
auto scale = decltype(out.real())(M_LN2);
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
} else {
return Simd<T, 1>{std::log2(in.value)};
}
}
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;

View File

@@ -119,12 +119,17 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
// Make sure that the last dimension is contiguous
auto set_output = [s = stream(), &out](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc(x.data_size() * x.itemsize()),
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
@@ -141,6 +146,18 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
auto in = set_output(inputs[0]);
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::runtime_error(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<float, float>(in, out, stream());
break;
@@ -161,9 +178,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
case float64:
softmax<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[softmax] Only defined for floating point types.");
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");
break;
}
}

View File

@@ -105,11 +105,15 @@ struct StridedIterator {
};
template <typename T>
void sort(array& out, int axis) {
void sort(const array& in, array& out, int axis, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream);
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -123,20 +127,30 @@ void sort(array& out, int axis) {
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride]() mutable {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
src_it.step();
}
std::stable_sort(st, ed);
src_it.step();
}
});
}
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
void argsort(const array& in, array& out, int axis, Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@@ -162,69 +176,99 @@ void argsort(const array& in, array& out, int axis) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride]() mutable {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
template <typename T>
void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream);
auto remaining_shape = out.shape();
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
size_t n_rows = in_size / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = out.strides();
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<T>(),
src_it = std::move(src_it),
n_rows,
axis_size,
axis_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
}
std::nth_element(st, md, ed);
}
});
}
template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
void argpartition(
const array& in,
array& out,
int axis,
int kth,
Stream stream) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@@ -253,32 +297,42 @@ void argpartition(const array& in, array& out, int axis, int kth) {
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in_ptr = in.data<T>(),
out_ptr = out.data<IdxT>(),
in_it = std::move(in_it),
out_it = std::move(out_it),
n_rows,
axis_size,
in_stride,
out_stride,
kth]() mutable {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();
out_it.step();
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator md(idx_ptr, out_stride, kth);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
});
}
} // namespace
@@ -287,184 +341,144 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_]() mutable {
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
case uint8:
return argsort<uint8_t>(in, out, axis_);
case uint16:
return argsort<uint16_t>(in, out, axis_);
case uint32:
return argsort<uint32_t>(in, out, axis_);
case uint64:
return argsort<uint64_t>(in, out, axis_);
case int8:
return argsort<int8_t>(in, out, axis_);
case int16:
return argsort<int16_t>(in, out, axis_);
case int32:
return argsort<int32_t>(in, out, axis_);
case int64:
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
});
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_, stream());
case uint8:
return argsort<uint8_t>(in, out, axis_, stream());
case uint16:
return argsort<uint16_t>(in, out, axis_, stream());
case uint32:
return argsort<uint32_t>(in, out, axis_, stream());
case uint64:
return argsort<uint64_t>(in, out, axis_, stream());
case int8:
return argsort<int8_t>(in, out, axis_, stream());
case int16:
return argsort<int16_t>(in, out, axis_, stream());
case int32:
return argsort<int32_t>(in, out, axis_, stream());
case int64:
return argsort<int64_t>(in, out, axis_, stream());
case float32:
return argsort<float>(in, out, axis_, stream());
case float64:
return argsort<double>(in, out, axis_, stream());
case float16:
return argsort<float16_t>(in, out, axis_, stream());
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_, stream());
case complex64:
return argsort<complex64_t>(in, out, axis_, stream());
}
}
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
switch (in.dtype()) {
case bool_:
return sort<bool>(in, out, axis_, stream());
case uint8:
return sort<uint8_t>(in, out, axis_, stream());
case uint16:
return sort<uint16_t>(in, out, axis_, stream());
case uint32:
return sort<uint32_t>(in, out, axis_, stream());
case uint64:
return sort<uint64_t>(in, out, axis_, stream());
case int8:
return sort<int8_t>(in, out, axis_, stream());
case int16:
return sort<int16_t>(in, out, axis_, stream());
case int32:
return sort<int32_t>(in, out, axis_, stream());
case int64:
return sort<int64_t>(in, out, axis_, stream());
case float32:
return sort<float>(in, out, axis_, stream());
case float64:
return sort<double>(in, out, axis_, stream());
case float16:
return sort<float16_t>(in, out, axis_, stream());
case bfloat16:
return sort<bfloat16_t>(in, out, axis_, stream());
case complex64:
return sort<complex64_t>(in, out, axis_, stream());
}
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Allocate output
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_input_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
}
});
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_, stream());
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
case int8:
return argpartition<int8_t>(in, out, axis_, kth_, stream());
case int16:
return argpartition<int16_t>(in, out, axis_, kth_, stream());
case int32:
return argpartition<int32_t>(in, out, axis_, kth_, stream());
case int64:
return argpartition<int64_t>(in, out, axis_, kth_, stream());
case float32:
return argpartition<float>(in, out, axis_, kth_, stream());
case float64:
return argpartition<double>(in, out, axis_, kth_, stream());
case float16:
return argpartition<float16_t>(in, out, axis_, kth_, stream());
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
}
}
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
axis_ = axis_,
kth_ = kth_]() mutable {
switch (out.dtype()) {
case bool_:
return partition<bool>(out, axis_, kth_);
case uint8:
return partition<uint8_t>(out, axis_, kth_);
case uint16:
return partition<uint16_t>(out, axis_, kth_);
case uint32:
return partition<uint32_t>(out, axis_, kth_);
case uint64:
return partition<uint64_t>(out, axis_, kth_);
case int8:
return partition<int8_t>(out, axis_, kth_);
case int16:
return partition<int16_t>(out, axis_, kth_);
case int32:
return partition<int32_t>(out, axis_, kth_);
case int64:
return partition<int64_t>(out, axis_, kth_);
case float32:
return partition<float>(out, axis_, kth_);
case float64:
return partition<double>(out, axis_, kth_);
case float16:
return partition<float16_t>(out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(out, axis_, kth_);
case complex64:
return partition<complex64_t>(out, axis_, kth_);
}
});
switch (in.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_, stream());
case uint8:
return partition<uint8_t>(in, out, axis_, kth_, stream());
case uint16:
return partition<uint16_t>(in, out, axis_, kth_, stream());
case uint32:
return partition<uint32_t>(in, out, axis_, kth_, stream());
case uint64:
return partition<uint64_t>(in, out, axis_, kth_, stream());
case int8:
return partition<int8_t>(in, out, axis_, kth_, stream());
case int16:
return partition<int16_t>(in, out, axis_, kth_, stream());
case int32:
return partition<int32_t>(in, out, axis_, kth_, stream());
case int64:
return partition<int64_t>(in, out, axis_, kth_, stream());
case float32:
return partition<float>(in, out, axis_, kth_, stream());
case float64:
return partition<double>(in, out, axis_, kth_, stream());
case float16:
return partition<float16_t>(in, out, axis_, kth_, stream());
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
case complex64:
return partition<complex64_t>(in, out, axis_, kth_, stream());
}
}
} // namespace mlx::core

View File

@@ -50,9 +50,9 @@ void svd_impl(
array& s = outputs[1];
array& vt = outputs[2];
u.set_data(allocator::malloc(u.nbytes()));
s.set_data(allocator::malloc(s.nbytes()));
vt.set_data(allocator::malloc(vt.nbytes()));
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
encoder.set_output_array(u);
encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
} else {
array& s = outputs[0];
s.set_data(allocator::malloc(s.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
encoder.set_output_array(s);
@@ -91,7 +91,7 @@ void svd_impl(
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
@@ -132,7 +132,7 @@ void svd_impl(
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {

View File

@@ -1,10 +1,12 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -126,28 +128,57 @@ void ternary_op(
const array& b,
const array& c,
array& out,
Op op,
TernaryOpType topt) {
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
if (topt == TernaryOpType::ScalarScalarScalar) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
encoder.dispatch(
[a_ptr, b_ptr, c_ptr, out_ptr, op = std::move(op)]() mutable {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
});
} else if (topt == TernaryOpType::VectorVectorVector) {
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
encoder.dispatch([a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size()]() mutable {
for (size_t i = 0; i < size; ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
});
} else {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
encoder.dispatch(
[a_ptr,
b_ptr,
c_ptr,
out_ptr,
op = std::move(op),
size = out.size(),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
ternary_op_dispatch_dims<T1, T2, T3, U>(
a_ptr, b_ptr, c_ptr, out_ptr, op, size, shape, strides);
});
}
}

View File

@@ -1,8 +1,5 @@
// Copyright © 2024 Apple Inc.
// Required for using M_LN2 in MSVC.
#define _USE_MATH_DEFINES
#include <cassert>
#include "mlx/backend/cpu/unary.h"
@@ -17,57 +14,88 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
unary_signed(in, out, detail::Abs(), stream());
auto op = detail::Abs{};
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos(), stream());
unary_fp(in, out, detail::ArcCos());
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh(), stream());
unary_fp(in, out, detail::ArcCosh());
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin(), stream());
unary_fp(in, out, detail::ArcSin());
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh(), stream());
unary_fp(in, out, detail::ArcSinh());
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan(), stream());
unary_fp(in, out, detail::ArcTan());
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh(), stream());
unary_fp(in, out, detail::ArcTanh());
}
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert(), stream());
unary_int(in, out, detail::BitwiseInvert());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil(), stream());
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -76,50 +104,84 @@ void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
unary_complex(inputs[0], out, detail::Conjugate(), stream());
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cos(), stream());
unary_fp(in, out, detail::Cos());
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh(), stream());
unary_fp(in, out, detail::Cosh());
}
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_real_fp(in, out, detail::Erf(), stream());
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_real_fp(in, out, detail::ErfInv(), stream());
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Exp(), stream());
unary_fp(in, out, detail::Exp());
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1(), stream());
unary_fp(in, out, detail::Expm1());
}
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor(), stream());
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -127,7 +189,7 @@ void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_complex_to_float(inputs[0], out, detail::Imag(), stream());
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -135,13 +197,13 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log(), stream());
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2(), stream());
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10(), stream());
unary_fp(in, out, detail::Log10());
break;
}
}
@@ -149,30 +211,30 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p(), stream());
unary_fp(in, out, detail::Log1p());
}
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot(), stream());
unary(in, out, detail::LogicalNot());
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative(), stream());
unary(in, out, detail::Negative());
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_complex_to_float(inputs[0], out, detail::Real(), stream());
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round(), stream());
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
out.copy_shared_buffer(in);
@@ -182,7 +244,7 @@ void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid(), stream());
unary_fp(in, out, detail::Sigmoid());
}
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -191,48 +253,48 @@ void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign(), stream());
unary(in, out, detail::Sign());
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sin(), stream());
unary_fp(in, out, detail::Sin());
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh(), stream());
unary_fp(in, out, detail::Sinh());
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square(), stream());
unary(in, out, detail::Square());
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt(), stream());
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt(), stream());
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tan(), stream());
unary_fp(in, out, detail::Tan());
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh(), stream());
unary_fp(in, out, detail::Tanh());
}
} // namespace mlx::core

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -18,13 +19,13 @@ void set_unary_output_data(const array& in, array& out) {
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc(size * out.itemsize()),
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
}
@@ -38,263 +39,156 @@ void unary_op(const T* a, U* out, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op) {
set_unary_output_data(a, out);
const T* src = a.data<T>();
U* dst = out.data<U>();
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
size -= N;
src += N;
dst += N;
}
while (size > 0) {
*dst = Op{}(*src);
size--;
dst++;
src++;
}
} else {
size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([src,
dst,
contig = a.flags().contiguous,
data_size = a.data_size(),
size = a.size(),
shapes = a.shape(),
strides = a.strides()]() mutable {
auto ndim = shapes.size();
if (contig) {
constexpr int N = simd::max_size<T>;
while (data_size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
data_size -= N;
src += N;
dst += N;
}
while (data_size > 0) {
*dst = Op{}(*src);
data_size--;
dst++;
src++;
}
} else {
size_t shape = ndim > 0 ? shapes.back() : 1;
size_t stride = ndim > 0 ? strides.back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(shapes, strides, ndim - 1);
for (size_t elem = 0; elem < size; elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}
}
});
}
template <typename Op>
void unary(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
break;
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
}
template <typename Op>
void unary(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bool_:
unary_op<bool>(a, out, op);
break;
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
}
});
void unary_fp(const array& a, array& out, Op op) {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
template <typename Op>
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_real] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_fp(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_fp] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
}
template <typename Op>
void unary_signed(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
case float16:
unary_op<float16_t>(a, out, op);
break;
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
});
}
template <typename Op>
void unary_complex(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
}
template <typename Op>
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch(
[a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
}
template <typename Op>
void unary_int(const array& a, array& out, Op op, Stream stream) {
set_unary_output_data(a, out);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
out = array::unsafe_weak_copy(out),
op = op]() mutable {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
});
void unary_int(const array& a, array& out, Op op) {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
} // namespace mlx::core

View File

@@ -86,14 +86,13 @@ struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
auto o = Simd<T, N>{1};
auto m = Simd<T, N>{-1};
if constexpr (std::is_unsigned_v<T>) {
return simd::select(x == z, z, o);
return x != z;
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(x < z, m, simd::select(x > z, o, z));
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
}
}
SINGLE()

View File

@@ -1,5 +0,0 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)

View File

@@ -1,9 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::gpu {
bool is_available();
} // namespace mlx::core::gpu

View File

@@ -1,49 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/primitives.h"
#include <cassert>
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
} // namespace mlx::core

View File

@@ -1,217 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include <cassert>
#define MLX_PROFILER_RANGE(message)
namespace mlx::core {
namespace {
void reshape(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc(out.nbytes()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
eval(inputs, out);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("AsType::eval_gpu");
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
copy_gpu(inputs[0], out, ctype);
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Broadcast::eval_gpu");
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("BroadcastAxes::eval_gpu");
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Concatenate::eval_gpu");
concatenate_gpu(inputs, out, axis_, stream());
}
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Contiguous::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Copy::eval_gpu");
eval(inputs, out);
}
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("CustomTransforms::eval_gpu");
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Depends::eval_gpu");
eval(inputs, outputs);
}
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
eval(inputs, out);
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Full::eval_gpu");
auto in = inputs[0];
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy_gpu(in, out, ctype);
}
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Flatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("NumberOfElements::eval_gpu");
eval(inputs, out);
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
pad_gpu(in, val, out, axes_, low_pad_size_, stream());
}
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Reshape::eval_gpu");
reshape(inputs[0], out, stream());
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
MLX_PROFILER_RANGE("Split::eval_gpu");
eval(inputs, outputs);
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Slice::eval_gpu");
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
slice_gpu(in, out, start_indices_, strides_, stream());
}
void Squeeze::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Squeeze::eval_gpu");
eval(inputs, out);
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("StopGradient::eval_gpu");
eval(inputs, out);
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Transpose::eval_gpu");
eval(inputs, out);
}
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
reshape(inputs[0], out, stream());
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("View::eval_gpu");
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc(tmp.nbytes()));
copy_gpu_inplace(in, tmp, CopyType::General, stream());
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@@ -1,44 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
namespace mlx::core {
void slice_gpu(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
slice(in, out, start_indices, strides);
}
void pad_gpu(
const array& in,
const array& val,
array& out,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes.size(); i++) {
auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
data_offset += out.strides()[ax] * low_pad_size[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
}
} // namespace mlx::core

View File

@@ -47,7 +47,6 @@ if(MLX_METAL_JIT)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(fft kernels/fft/radix.h kernels/fft/readwrite.h)
make_jit_source(logsumexp)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
@@ -61,7 +60,6 @@ if(MLX_METAL_JIT)
kernels/steel/gemm/transforms.h)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(steel/gemm/kernels/steel_gemm_masked kernels/steel/defines.h)
make_jit_source(steel/gemm/kernels/steel_gemm_gather)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
@@ -93,12 +91,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include "mlx/memory.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -20,9 +20,6 @@ Allocator& allocator() {
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<MTL::Buffer*>(ptr_)->contents();
}
@@ -32,11 +29,8 @@ namespace metal {
namespace {
BufferCache::BufferCache(ResidencySet& residency_set)
: head_(nullptr),
tail_(nullptr),
pool_size_(0),
residency_set_(residency_set) {}
BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::~BufferCache() {
auto pool = metal::new_scoped_memory_pool();
@@ -47,9 +41,6 @@ int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) {
if (!holder->buf->heap()) {
residency_set_.erase(holder->buf);
}
holder->buf->release();
n_release++;
}
@@ -107,9 +98,6 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
if (tail_->buf) {
total_bytes_freed += tail_->buf->length();
if (!tail_->buf->heap()) {
residency_set_.erase(tail_->buf);
}
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
@@ -164,7 +152,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(residency_set_) {
buffer_cache_(device_) {
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
@@ -201,19 +189,16 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit) {
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
size_t MetalAllocator::get_memory_limit() {
return block_limit_;
}
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_);
@@ -221,7 +206,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
return limit;
};
Buffer MetalAllocator::malloc(size_t size) {
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
if (size == 0) {
return Buffer{nullptr};
@@ -248,6 +233,11 @@ Buffer MetalAllocator::malloc(size_t size) {
if (!buf) {
size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size,
@@ -271,13 +261,9 @@ Buffer MetalAllocator::malloc(size_t size) {
if (!buf) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
return Buffer{nullptr};
}
lk.lock();
num_resources_++;
if (!buf->heap()) {
residency_set_.insert(buf);
if (buf) {
num_resources_++;
}
}
@@ -291,6 +277,10 @@ Buffer MetalAllocator::malloc(size_t size) {
get_cache_memory() - max_pool_size_);
}
if (!buf->heap()) {
residency_set_.insert(buf);
}
return Buffer{static_cast<void*>(buf)};
}
@@ -306,14 +296,14 @@ void MetalAllocator::free(Buffer buffer) {
return;
}
std::unique_lock lk(mutex_);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
num_resources_--;
if (!buf->heap()) {
residency_set_.erase(buf);
}
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
@@ -332,40 +322,37 @@ MetalAllocator& allocator() {
return *allocator_;
}
} // namespace metal
size_t set_cache_limit(size_t limit) {
return metal::allocator().set_cache_limit(limit);
return allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit) {
return metal::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return metal::allocator().get_memory_limit();
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t set_wired_limit(size_t limit) {
if (limit > std::get<size_t>(metal::device_info().at(
"max_recommended_working_set_size"))) {
if (limit >
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return metal::allocator().set_wired_limit(limit);
return allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return metal::allocator().get_active_memory();
return allocator().get_active_memory();
}
size_t get_peak_memory() {
return metal::allocator().get_peak_memory();
return allocator().get_peak_memory();
}
void reset_peak_memory() {
metal::allocator().reset_peak_memory();
allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return metal::allocator().get_cache_memory();
return allocator().get_cache_memory();
}
void clear_cache() {
return metal::allocator().clear_cache();
return allocator().clear_cache();
}
} // namespace metal
} // namespace mlx::core

View File

@@ -18,7 +18,7 @@ namespace {
class BufferCache {
public:
BufferCache(ResidencySet& residency_set);
BufferCache(MTL::Device* device);
~BufferCache();
MTL::Buffer* reuse_from_cache(size_t size);
@@ -42,11 +42,13 @@ class BufferCache {
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
MTL::Heap* heap_{nullptr};
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
BufferHolder* tail_;
size_t pool_size_;
ResidencySet& residency_set_;
};
} // namespace
@@ -54,7 +56,7 @@ class BufferCache {
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size) override;
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
@@ -71,8 +73,7 @@ class MetalAllocator : public allocator::Allocator {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit);
size_t get_memory_limit();
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache();

View File

@@ -90,7 +90,7 @@ void binary_op_gpu_inplace(
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = get_work_per_thread(a.dtype());
work_per_thread = 1;
}
std::string kernel_name =
get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
@@ -137,20 +137,13 @@ void binary_op_gpu_inplace(
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), arg_idx++);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), arg_idx++);
grid_dims = MTL::Size(nthreads, 1, 1);
}
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}

View File

@@ -64,7 +64,6 @@ inline void build_kernel(
cnt++);
}
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (add_indices) {
os += fmt::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
@@ -84,9 +83,6 @@ inline void build_kernel(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} else {
os += fmt::format(
" constant const {0}& size [[buffer({1})]],\n", idx_type, cnt++);
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
@@ -96,14 +92,13 @@ inline void build_kernel(
os += " uint3 pos [[thread_position_in_grid]],\n";
os += " uint3 grid [[threads_per_grid]]) {\n";
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
std::string idx_type = use_big_index ? "int64_t" : "uint";
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os += " int64_t index = N_ * (pos.x + grid.x * int64_t(pos.y));\n";
} else if (contiguous) {
os += " uint index = N_ * pos.x;\n";
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
@@ -115,9 +110,6 @@ inline void build_kernel(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
if (work_per_thread > 1 && contiguous) {
os += " for (int i = 0; i < N_ && index < size; ++i) {\n";
}
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
@@ -201,7 +193,7 @@ inline void build_kernel(
}
// Open per-thread loop
if (work_per_thread > 1 && !contiguous) {
if (work_per_thread > 1) {
os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
@@ -280,7 +272,6 @@ void Compiled::eval_gpu(
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() {
int work_per_thread = get_work_per_thread(outputs_[0].dtype());
std::string kernel = metal::utils();
concatenate(
kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
@@ -293,9 +284,7 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ work_per_thread);
/* dynamic_dims = */ false);
build_kernel(
kernel,
kernel_lib_ + "_contiguous_large",
@@ -306,8 +295,7 @@ void Compiled::eval_gpu(
/* contiguous = */ true,
/* ndim = */ 0,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ work_per_thread);
/* use_big_index = */ true);
for (int i = 1; i < 8; i++) {
build_kernel(
kernel,
@@ -480,13 +468,6 @@ void Compiled::eval_gpu(
if (!contiguous) {
compute_encoder.set_vector_bytes(strides[0], cnt++);
compute_encoder.set_vector_bytes(shape, cnt++);
} else {
auto size = outputs[0].data_size();
if (large) {
compute_encoder.set_bytes<int64_t>(size, cnt++);
} else {
compute_encoder.set_bytes<int>(size, cnt++);
}
}
// Put the number of dims in if it is dynamic
@@ -496,13 +477,12 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
int work_per_thread = get_work_per_thread(outputs[0].dtype());
size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread);
size_t nthreads = outputs[0].data_size();
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = large
? get_2d_grid_dims(
outputs[0].shape(), outputs[0].strides(), work_per_thread)
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {

View File

@@ -5,7 +5,7 @@
#include <numeric>
#include <sstream>
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
// Do filter transform
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
// Do input transform
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
// Do batched gemm
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
@@ -712,65 +712,6 @@ void winograd_conv_2D_gpu(
}
}
void depthwise_conv_2D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
std::ostringstream kname;
kname << "depthwise_conv_2d_" << type_to_name(out);
std::string base_name = kname.str();
const int N = conv_params.N;
const int ker_h = conv_params.wS[0];
const int ker_w = conv_params.wS[1];
const int str_h = conv_params.str[0];
const int str_w = conv_params.str[1];
const int tc = 8;
const int tw = 8;
const int th = 4;
const bool do_flip = conv_params.flip;
metal::MTLFCList func_consts = {
{&ker_h, MTL::DataType::DataTypeInt, 00},
{&ker_w, MTL::DataType::DataTypeInt, 01},
{&str_h, MTL::DataType::DataTypeInt, 10},
{&str_w, MTL::DataType::DataTypeInt, 11},
{&th, MTL::DataType::DataTypeInt, 100},
{&tw, MTL::DataType::DataTypeInt, 101},
{&do_flip, MTL::DataType::DataTypeBool, 200},
};
// clang-format off
kname << "_ker_h_" << ker_h
<< "_ker_w_" << ker_w
<< "_str_h_" << str_h
<< "_str_w_" << str_w
<< "_tgp_h_" << th
<< "_tgp_w_" << tw
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(tc, tw, th);
MTL::Size grid_dims = MTL::Size(
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -813,20 +754,11 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
if (is_idil_one && groups > 1) {
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
conv_params.wt_strides[1] == conv_params.wS[1] &&
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
@@ -923,7 +855,7 @@ void conv_3D_gpu(
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);

View File

@@ -1,15 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include <sstream>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -84,8 +104,6 @@ void copy_gpu_inplace(
"[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy");
}
}
} else {
work_per_thread = get_work_per_thread(in.dtype());
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out)
@@ -147,28 +165,44 @@ void copy_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatch_threads(grid_dims, group_dims);
} else {
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const Strides& i_strides,
int64_t i_offset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
}
void fill_gpu(const array& val, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
@@ -180,21 +214,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
int work_per_thread = get_work_per_thread(val.dtype());
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = ceildiv(out.data_size(), work_per_thread);
size_t nthreads = out.data_size();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims;
if (large) {
compute_encoder.set_bytes<int64_t>(out.data_size(), 2);
grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread);
} else {
compute_encoder.set_bytes<int>(out.data_size(), 2);
grid_dims = MTL::Size(nthreads, 1, 1);
}
MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@@ -5,8 +5,6 @@
#include "mlx/backend/common/copy.h"
#include "mlx/stream.h"
#include <optional>
namespace mlx::core {
// Generic copy inplace

View File

@@ -1,6 +1,6 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
}

View File

@@ -1,24 +1,27 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH;
auto get_metal_version() {
@@ -55,10 +58,7 @@ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
}
#ifdef SWIFTPM_BUNDLE
MTL::Library* try_load_bundle(
MTL::Device* device,
NS::URL* url,
const std::string& lib_name) {
MTL::Library* try_load_bundle(MTL::Device* device, NS::URL* url) {
std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
SWIFTPM_BUNDLE + ".bundle";
auto bundle = NS::Bundle::alloc()->init(
@@ -66,7 +66,7 @@ MTL::Library* try_load_bundle(
if (bundle != nullptr) {
std::string resource_path =
std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
lib_name + ".metallib";
"default.metallib";
auto [lib, error] = load_library_from_path(device, resource_path.c_str());
if (lib) {
return lib;
@@ -76,133 +76,51 @@ MTL::Library* try_load_bundle(
}
#endif
// Firstly, search for the metallib in the same path as this binary
std::pair<MTL::Library*, NS::Error*> load_colocated_library(
MTL::Device* device,
const std::string& relative_path) {
std::string binary_dir = get_binary_directory();
if (binary_dir.size() == 0) {
return {nullptr, nullptr};
}
auto path = fs::path(binary_dir) / relative_path;
if (!path.has_extension()) {
path.replace_extension(".metallib");
}
return load_library_from_path(device, path.c_str());
}
std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
MTL::Device* device,
const std::string& lib_name) {
#ifdef SWIFTPM_BUNDLE
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL(), lib_name);
if (library != nullptr) {
return {library, nullptr};
}
}
#endif
return {nullptr, nullptr};
}
MTL::Library* load_default_library(MTL::Device* device) {
NS::Error* error[4];
MTL::Library* lib;
// First try the colocated mlx.metallib
std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
if (lib) {
return lib;
}
std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
if (lib) {
return lib;
}
// Then try default.metallib in a SwiftPM bundle if we have one
std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
if (lib) {
return lib;
}
// Finally try default_mtllib_path
std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the default metallib. ";
for (int i = 0; i < 4; i++) {
if (error[i] != nullptr) {
msg << error[i]->localizedDescription()->utf8String() << " ";
}
}
throw std::runtime_error(msg.str());
}
return lib;
}
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name,
const std::string& lib_path) {
// We have been given a path that ends in metallib so try to load it
if (lib_path.size() > 9 &&
std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
auto [lib, error] = load_library_from_path(device, lib_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << lib_path << "> with error "
<< error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// We have been given a path so try to load from lib_path / lib_name.metallib
if (lib_path.size() > 0) {
std::string full_path = lib_path + "/" + lib_name + ".metallib";
auto [lib, error] = load_library_from_path(device, full_path.c_str());
if (!lib) {
std::ostringstream msg;
msg << "Failed to load the metallib from <" << full_path
<< "> with error " << error->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
return lib;
}
// Try to load the colocated library
{
auto [lib, error] = load_colocated_library(device, lib_name);
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name);
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
if (lib) {
return lib;
}
}
// Try to load the library from swiftpm
{
auto [lib, error] = load_swiftpm_library(device, lib_name);
if (lib) {
return lib;
}
}
std::ostringstream msg;
msg << "Failed to load the metallib " << lib_name << ".metallib. "
<< "We attempted to load it from <" << get_binary_directory() << "/"
<< lib_name << ".metallib" << ">";
#ifdef SWIFTPM_BUNDLE
msg << " and from the Swift PM bundle.";
// try to load from a swiftpm resource bundle -- scan the available bundles to
// find one that contains the named bundle
{
MTL::Library* library =
try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL());
if (library != nullptr) {
return library;
}
auto bundles = NS::Bundle::allBundles();
for (int i = 0, c = (int)bundles->count(); i < c; i++) {
auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
library = try_load_bundle(device, bundle->resourceURL());
if (library != nullptr) {
return library;
}
}
}
#endif
throw std::runtime_error(msg.str());
// Couldn't find it so let's load it from default_mtllib_path
{
auto [lib, error] = load_library_from_path(device, lib_path);
if (!lib) {
std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n"
<< "Failed to load device library from <" << lib_path << ">"
<< " or <" << first_path << ">.";
throw std::runtime_error(msg.str());
}
return lib;
}
}
} // namespace
@@ -295,7 +213,7 @@ void CommandEncoder::barrier() {
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_default_library(device_)}};
library_map_ = {{"mlx", load_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
@@ -338,7 +256,7 @@ Device::~Device() {
void Device::new_queue(int index) {
auto thread_pool = metal::new_scoped_memory_pool();
auto q = device_->newCommandQueue();
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
@@ -769,4 +687,42 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}
}
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;
}
} // namespace mlx::core::metal

View File

@@ -21,14 +21,18 @@ namespace mlx::core::metal {
// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_binary_directory() {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string directory;
int success = dladdr((void*)get_binary_directory, &info);
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
directory = fs::path(info.dli_fname).remove_filename().c_str();
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return directory;
return mtllib_path;
}
using MTLFCList =
@@ -185,7 +189,15 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path = "");
const std::string& lib_path);
// Note, this should remain in the header so that it is not dynamically
// linked
void register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
MTL::Library* get_library(
const std::string& name,
@@ -266,6 +278,4 @@ class Device {
Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
} // namespace mlx::core::metal

View File

@@ -4,7 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"

View File

@@ -1,102 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <memory>
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::gpu {
bool is_available() {
return true;
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
metal::device(stream.device).new_queue(stream.index);
}
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = metal::new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
} // namespace mlx::core::gpu

View File

@@ -2,6 +2,7 @@
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
namespace mlx::core {
@@ -23,6 +24,10 @@ void Event::wait() {
}
}
void Event::signal() {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
}
void Event::wait(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
@@ -37,9 +42,7 @@ void Event::wait(Stream stream) {
void Event::signal(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable {
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
});
scheduler::enqueue(stream, [*this]() mutable { signal(); });
} else {
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);

View File

@@ -1,11 +1,39 @@
// Copyright © 2024 Apple Inc.
#include "mlx/fence.h"
#include <csignal>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core {
void signal_handler(int signum);
MTL::Buffer* signal_buffer() {
auto init = []() {
signal(SIGTERM, signal_handler);
auto dtor = [](void* buf) {
allocator::free(static_cast<MTL::Buffer*>(buf));
};
auto buf = std::shared_ptr<void>(
allocator::malloc_or_wait(sizeof(uint32_t)).ptr(), dtor);
static_cast<uint32_t*>(
static_cast<MTL::Buffer*>(buf.get())->contents())[0] = 0;
return buf;
};
static std::shared_ptr<void> buf = init();
return static_cast<MTL::Buffer*>(buf.get());
}
void signal_handler(int signum) {
auto buf = signal_buffer();
static_cast<std::atomic_uint*>(buf->contents())[0] = 1;
signal(signum, SIG_DFL);
raise(signum);
}
struct FenceImpl {
FenceImpl() {
auto d = metal::device(Device::gpu).mtl_device();
@@ -19,7 +47,7 @@ struct FenceImpl {
auto p = metal::new_scoped_memory_pool();
fence = static_cast<void*>(d->newSharedEvent());
} else {
auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
fence = static_cast<void*>(buf);
cpu_value()[0] = 0;
}
@@ -93,6 +121,7 @@ void Fence::wait(Stream stream, const array& x) {
auto buf = static_cast<MTL::Buffer*>(f.fence);
compute_encoder.set_buffer(buf, 0);
compute_encoder.set_bytes(f.count, 1);
compute_encoder.set_buffer(signal_buffer(), 2);
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
d.get_command_buffer(idx)->addCompletedHandler(
@@ -138,7 +167,7 @@ void Fence::update(Stream stream, const array& x) {
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
// Barrier on previous kernels
compute_encoder.barrier();

View File

@@ -7,10 +7,10 @@
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/gpu/slicing.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc(w_k.nbytes()));
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc(w_q.nbytes()));
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
@@ -356,14 +356,20 @@ void multi_upload_bluestein_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array>& copies,
std::vector<array> copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
copies.push_back(w_k);
copies.push_back(w_q);
// Broadcast w_q and w_k to the batch size
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
@@ -372,13 +378,13 @@ void multi_upload_bluestein_fft(
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
copies.push_back(temp);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
Shape rstarts(in.ndim(), 0);
@@ -388,28 +394,19 @@ void multi_upload_bluestein_fft(
unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
copies.push_back(temp);
} else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s);
copies.push_back(temp);
} else {
temp.copy_shared_buffer(in);
}
Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast(temp.shape(), complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
auto zero = array(complex64_t{0.0f, 0.0f});
copies.push_back(zero);
pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
copies.push_back(pad_temp);
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
@@ -421,10 +418,7 @@ void multi_upload_bluestein_fft(
FourStepParams(),
/*inplace=*/false,
s);
copies.push_back(pad_temp1);
array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op(
@@ -441,11 +435,9 @@ void multi_upload_bluestein_fft(
Shape starts(in.ndim(), 0);
Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
array temp2(temp_shape, complex64, nullptr, {});
slice_gpu(pad_temp1, temp2, starts, strides, s);
binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
Shape rstarts(in.ndim(), 0);
@@ -457,21 +449,26 @@ void multi_upload_bluestein_fft(
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copies.push_back(temp1);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
array temp3(temp_shape, complex64, nullptr, {});
unary_op_gpu({temp1}, temp3, "Conjugate", s);
binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
unary_op_gpu({temp1}, temp, "Conjugate", s);
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
copies.push_back(inv_n);
copies.push_back(temp1);
copies.push_back(temp3);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
@@ -481,9 +478,8 @@ void four_step_fft(
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array>& copies,
const Stream& s,
bool in_place) {
std::vector<array> copies,
const Stream& s) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
@@ -496,14 +492,7 @@ void four_step_fft(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp,
out,
axis,
inverse,
real,
four_step_params,
/*inplace=*/in_place,
s);
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
@@ -562,7 +551,8 @@ void fft_op(
flags.row_contiguous = is_row_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
@@ -585,7 +575,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -593,7 +583,7 @@ void fft_op(
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc(out.nbytes()),
allocator::malloc_or_wait(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());

View File

@@ -1,9 +1,11 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/hadamard.h"
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
@@ -13,6 +15,7 @@
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
@@ -57,142 +60,121 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void hadamard_mn_contiguous(
const array& x,
array& y,
int m,
int n1,
int n2,
float scale,
metal::Device& d,
const Stream& s) {
int n = n1 * n2;
int read_width_n1 = n1 == 2 ? 2 : 4;
int read_width_n2 = n2 == 2 ? 2 : 4;
int read_width_m = (n == 2 || m == 28) ? 2 : 4;
int max_radix_1 = std::min(n1, 16);
int max_radix_2 = std::min(n2, 16);
float scale_n1 = 1.0;
float scale_n2 = (m == 1) ? scale : 1.0;
float scale_m = scale;
// n2 is a row contiguous power of 2 hadamard transform
MTL::Size group_dims_n2(n2 / max_radix_2, 1, 1);
MTL::Size grid_dims_n2(n2 / max_radix_2, x.size() / n2, 1);
// n1 is a strided power of 2 hadamard transform with stride n2
MTL::Size group_dims_n1(n1 / max_radix_1, 1, 1);
MTL::Size grid_dims_n1(n1 / max_radix_1, x.size() / n, n2);
// m is a strided hadamard transform with stride n = n1 * n2
MTL::Size group_dims_m(
std::min(n / read_width_m, MAX_HADAMARD_THREADS_PER_GROUP), 1, 1);
MTL::Size grid_dims_m(
group_dims_m.width, x.size() / m / read_width_m / group_dims_m.width, 1);
// Make the kernel
std::string kname;
kname.reserve(32);
concatenate(kname, "hadamard_", n * m, "_", type_to_name(x));
auto lib = d.get_library(kname, [&]() {
std::string kernel;
concatenate(
kernel,
metal::utils(),
gen_hadamard_codelet(m),
metal::hadamard(),
get_template_definition(
"n2" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n2,
max_radix_2,
read_width_n2));
if (n1 > 1) {
kernel += get_template_definition(
"n1" + kname,
"hadamard_n",
get_type_string(x.dtype()),
n1,
max_radix_1,
read_width_n1,
n2);
}
if (m > 1) {
kernel += get_template_definition(
"m" + kname,
"hadamard_m",
get_type_string(x.dtype()),
n,
m,
read_width_m);
}
return kernel;
});
// Launch the strided transform for n1
if (n1 > 1) {
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n1" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n1, 2);
compute_encoder.dispatch_threads(grid_dims_n1, group_dims_n1);
}
// Launch the transform for n2
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel("n2" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(n1 > 1 ? y : x, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_n2, 2);
compute_encoder.dispatch_threads(grid_dims_n2, group_dims_n2);
// Launch the strided transform for m
if (m > 1) {
auto kernel = d.get_kernel("m" + kname, lib);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(y, 0);
compute_encoder.set_output_array(y, 1);
compute_encoder.set_bytes(scale_m, 2);
compute_encoder.dispatch_threads(grid_dims_m, group_dims_m);
}
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
// Split the hadamard transform so that all of them work on vectors smaller
// than 8192 elements.
//
// We decompose it in the following way:
//
// n = m * n1 * n2 = m * 2^k1 * 2^k2
//
// where m is in (1, 12, 20, 28) and n1 and n2 <= 8192
auto [n, m] = decompose_hadamard(in.shape().back());
int n1 = 1, n2 = n;
if (n > 8192) {
for (n2 = 2; n2 * n2 < n; n2 *= 2) {
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
n1 = n / n2;
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
hadamard_mn_contiguous(in, out, m, n1, n2, scale_, d, s);
} else {
copy_gpu(in, out, CopyType::General, s);
hadamard_mn_contiguous(out, out, m, n1, n2, scale_, d, s);
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
return kernel_source.str();
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(scale, 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More