Compare commits

..

9 Commits

Author SHA1 Message Date
Angelos Katharopoulos
4515866024 Change the linux test to ubuntu 24.04 2025-01-20 22:58:05 -08:00
Angelos Katharopoulos
6fe2b82926 Add gcc 13 to the linux build 2025-01-17 18:00:24 -08:00
Angelos Katharopoulos
c75b5e9d19 Do not build wheels every time 2025-01-17 17:42:11 -08:00
Angelos Katharopoulos
6f12eda549 Test for 13.5 as well 2025-01-17 17:39:20 -08:00
Angelos Katharopoulos
a541fe9312 Set deployment target 13.5 2025-01-17 17:16:36 -08:00
Angelos Katharopoulos
2bdd20f257 Test XCode 16 2025-01-17 14:14:45 -08:00
Angelos Katharopoulos
aa7b9688ce Move some kernels to get_template_definition 2025-01-17 13:24:17 -08:00
Angelos Katharopoulos
0a41393dba Replace fmt::format with std::format 2025-01-17 11:24:16 -08:00
Angelos Katharopoulos
e300a01f4a Testing C++ 20 2025-01-16 18:24:16 -08:00
202 changed files with 6023 additions and 10321 deletions

View File

@@ -24,7 +24,7 @@ jobs:
type: boolean
default: false
macos:
xcode: "15.2.0"
xcode: "16.0.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
@@ -70,8 +70,8 @@ jobs:
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
machine:
image: ubuntu-2404:2024.11.1
steps:
- checkout
@@ -84,30 +84,33 @@ jobs:
- run:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install numpy
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get update -y
sudo apt-get install -y python3.9 python3.9-distutils python3.9-dev
python3.9 -m pip install --upgrade cmake
python3.9 -m pip install nanobind==2.4.0
python3.9 -m pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libopenblas-dev liblapacke-dev openmpi-bin libopenmpi-dev
- run:
name: Install Python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
python3.9 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF -DPython_EXECUTABLE=/usr/bin/python3.9" \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
python3 setup.py develop
python3.9 setup.py develop
- run:
name: Generate package stubs
command: |
echo "stubs"
pip install typing_extensions
python setup.py generate_stubs
python3.9 -m pip install typing_extensions
python3.9 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
python3.9 -m unittest discover python/tests -v
- run:
name: Build CPP only
command: |
@@ -122,7 +125,10 @@ jobs:
parameters:
xcode_version:
type: string
default: "15.2.0"
default: "16.0.0"
deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1
@@ -146,7 +152,9 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
pip install -e . -v
- run:
name: Generate package stubs
command: |
@@ -160,7 +168,6 @@ jobs:
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build example extension
command: |
@@ -174,7 +181,11 @@ jobs:
name: Build CPP only
command: |
source env/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
mkdir -p build
cd build/
cmake .. \
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
@@ -189,14 +200,15 @@ jobs:
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
-DMLX_METAL_JIT=ON \
-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
source env/bin/activate
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
CMAKE_ARGS="-DMLX_METAL_JIT=ON -DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >>" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
@@ -209,7 +221,10 @@ jobs:
default: "3.9"
xcode_version:
type: string
default: "15.2.0"
default: "16.0.0"
deployment_target:
type: string
default: ""
build_env:
type: string
default: ""
@@ -238,6 +253,7 @@ jobs:
source env/bin/activate
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
pip install . -v
- run:
name: Generate package stubs
@@ -251,6 +267,7 @@ jobs:
source env/bin/activate
<< parameters.build_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
CMAKE_ARGS=-DCMAKE_OSX_DEPLOYMENT_TARGET=<< parameters.deployment_target >> \
python -m build -w
- when:
condition: << parameters.build_env >>
@@ -331,9 +348,10 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
- linux_build_and_test
- build_documentation
- build_documentation
build_pypi_release:
when:
@@ -351,7 +369,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
@@ -375,7 +394,8 @@ workflows:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -388,7 +408,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
weekly_build:
when:
and:
@@ -399,7 +420,8 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
xcode_version: ["16.0.0"]
deployment_target: ["", "13.5"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:

View File

@@ -1,16 +1,16 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
rev: v19.1.4
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.1.0
rev: 24.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.0
rev: 5.13.2
hooks:
- id: isort
args:

View File

@@ -4,7 +4,7 @@ project(mlx LANGUAGES C CXX)
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -25,7 +25,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.23.0)
set(MLX_VERSION 0.22.0)
endif()
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
@@ -147,7 +147,6 @@ if(MLX_BUILD_CPU)
if(MLX_BUILD_ACCELERATE)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(MLX_USE_ACCELERATE)
add_compile_definitions(ACCELERATE_NEW_LAPACK)
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
# Download and build OpenBLAS from source code.
@@ -224,14 +223,6 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(

View File

@@ -8,23 +8,14 @@ L = 16384
H = 32
H_k = H // 4
D = 128
V = 128
dtype = mx.float16
loops = 10
def upproject(x, w):
if w is None:
return x
else:
return x @ w.T
def attention(q, k, v, mask=None, w=None):
def attention(q, k, v, mask=None):
def _sdpa(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
_, _, _, V = v.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
@@ -34,18 +25,16 @@ def attention(q, k, v, mask=None, w=None):
s = mx.where(m, s, mx.finfo(s.dtype).min)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, V)
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
q = upproject(q, w)
return q
def sdpa(q, k, v, mask=None, w=None):
def sdpa(q, k, v, mask=None):
for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
q = upproject(q, w)
return q
@@ -53,37 +42,34 @@ def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(attention, q, k, v, w=w)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
mx.eval(q, k, v, w)
time_fn(sdpa, q, k, v, w=w)
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
def time_self_attention_sdpa_with_mask():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mask = mx.full((L,), True)
mask[L // 2 :] = False
mx.eval(q, k, v, mask, w)
mx.eval(q, k, v, mask)
def sdpa_mask(*args):
return sdpa(*args, mask=mask, w=w)
return sdpa(*args, mask=mask)
def attention_mask(*args):
return attention(*args, mask=mask, w=w)
return attention(*args, mask=mask)
time_fn(attention_mask, q, k, v)
time_fn(sdpa_mask, q, k, v)

View File

@@ -1,55 +0,0 @@
import time
import mlx.core as mx
rank = mx.distributed.init().rank()
def timeit(fn, a):
# warmup
for _ in range(5):
mx.eval(fn(a))
its = 10
tic = time.perf_counter()
for _ in range(its):
mx.eval(fn(a))
toc = time.perf_counter()
ms = 1000 * (toc - tic) / its
return ms
def all_reduce_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_sum(x)
x = x - 1
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
def all_gather_benchmark():
a = mx.ones((5, 5), mx.int32)
its_per_eval = 100
def fn(x):
for _ in range(its_per_eval):
x = mx.distributed.all_gather(x)[0]
return x
ms = timeit(fn, a) / its_per_eval
if rank == 0:
print(f"All gather: time per iteration {ms:.6f} (ms)")
if __name__ == "__main__":
all_reduce_benchmark()
all_gather_benchmark()

View File

@@ -51,20 +51,11 @@ The default floating point type is ``float32`` and the default integer type is
* - ``float32``
- 4
- 32-bit float
* - ``float64``
- 4
- 64-bit double
* - ``complex64``
- 8
- 64-bit complex float
.. note::
Arrays with type ``float64`` only work with CPU operations. Using
``float64`` arrays on the GPU will result in an exception.
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.

View File

@@ -5,8 +5,8 @@ Linear Algebra
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
.. autosummary::
:toctree: _autosummary
inv
tri_inv
@@ -18,7 +18,3 @@ Linear Algebra
svd
eigvalsh
eigh
lu
lu_factor
solve
solve_triangular

View File

@@ -32,7 +32,6 @@ Operations
atleast_2d
atleast_3d
bitwise_and
bitwise_invert
bitwise_or
bitwise_xor
block_masked_mm

View File

@@ -57,7 +57,7 @@ with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
$ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by

View File

@@ -21,13 +21,11 @@ Let's convert an array to NumPy and back.
.. note::
Since NumPy does not support ``bfloat16`` arrays, you will need to convert
to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
buffer format string does not match the dtype V item size 0.``
Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
By default, NumPy copies data to a new array. This can be prevented by creating
an array view:
By default, NumPy copies data to a new array. This can be prevented by creating an array view:
.. code-block:: python
@@ -37,16 +35,10 @@ an array view:
a_view[0] = 1
print(a[0].item()) # 1
.. note::
A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.
NumPy arrays with type ``float64`` will be default converted to MLX arrays
with type ``float32``.
A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.
While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
Let's demonstrate this in an example:
@@ -64,12 +56,11 @@ Let's demonstrate this in an example:
The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
PyTorch
@@ -80,8 +71,7 @@ PyTorch
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python
@@ -92,8 +82,7 @@ PyTorch supports the buffer protocol, but it requires an explicit
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
Conversion from PyTorch tensors back to arrays must be done via intermediate
NumPy arrays with ``numpy()``.
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
JAX
---
@@ -111,8 +100,7 @@ JAX fully supports the buffer protocol.
TensorFlow
----------
TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
.. code-block:: python

View File

@@ -6,7 +6,6 @@
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/utils.h"
#include "axpby/axpby.h"

View File

@@ -29,16 +29,21 @@ if(WIN32)
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
if(MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
endif()
if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)

View File

@@ -25,18 +25,7 @@ array::array(
std::move(shape),
dtype,
std::move(primitive),
std::move(inputs))) {
if (has_primitive() && this->primitive().stream().device == Device::gpu) {
for (auto& in : this->inputs()) {
if (in.dtype() == float64) {
throw std::invalid_argument("float64 is not supported on the GPU");
}
}
if (this->dtype() == float64) {
throw std::invalid_argument("float64 is not supported on the GPU");
}
}
}
std::move(inputs))) {}
std::vector<array> array::make_arrays(
std::vector<Shape> shapes,
@@ -76,18 +65,6 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
set_data(data, deleter);
}
array::array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, data_size, std::move(strides), flags, deleter);
}
void array::detach() {
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();

View File

@@ -243,18 +243,6 @@ class array {
bool col_contiguous : 1;
};
/** Build an array from all the info held by the array description. Including
* the buffer, strides, flags.
*/
explicit array(
allocator::Buffer data,
Shape shape,
Dtype dtype,
Strides strides,
size_t data_size,
Flags flags,
Deleter deleter = allocator::free);
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
@@ -606,9 +594,6 @@ void array::init(It src) {
case float32:
std::copy(src, src + size(), data<float>());
break;
case float64:
std::copy(src, src + size(), data<double>());
break;
case bfloat16:
std::copy(src, src + size(), data<bfloat16_t>());
break;

View File

@@ -0,0 +1,8 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)

View File

@@ -0,0 +1,20 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
// TODO: Add accelerate based optimizations for CPU conv
}
} // namespace mlx::core

View File

@@ -0,0 +1,253 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
}
inline void matmul_cblas_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_cblas_general(a_pre, b_pre, out);
}
inline void matmul_bnns_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
// TODO: Update to utilize BNNS broadcasting
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
/* bool a_is_weights = */ false,
/* bool b_is_weights = */ false,
/* BNNSNDArrayDescriptor iA_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, lda, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor iB_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, ldb, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor o_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{N, M, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, N, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
};
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
}
BNNSFilterDestroy(bnns_filter);
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_bnns_general(a_pre, b_pre, out);
}
template <typename T>
inline void mask_matrix(
T* data,
const bool* mask,
int tile_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
int tX = (X + tile_size - 1) / tile_size;
int tY = (Y + tile_size - 1) / tile_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
int loc_x = i * tile_size;
int loc_y = j * tile_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
int size_x = std::min(tile_size, X - loc_x);
int size_y = std::min(tile_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
}
}
}
}
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32) {
return matmul_cblas(inputs[0], inputs[1], out);
}
return matmul_bnns(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
if (out.dtype() == float32) {
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
}
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

View File

@@ -0,0 +1,603 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include <Accelerate/Accelerate.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
// Use the default implementation for the following primitives
DEFAULT(Arange)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(ExpandDims)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(Squeeze)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
} else {
eval(inputs, out);
}
}
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x + y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().contiguous) {
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
}
eval(inputs, out);
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x / y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x / y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpm1f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
if (in.data_size() == 1 && out.dtype() == float32) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
} else {
eval(inputs, out);
}
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
switch (base_) {
case Base::e:
vvlogf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::two:
vvlog2f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
case Base::ten:
vvlog10f(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
break;
}
} else {
eval(inputs, out);
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else {
eval(inputs, out);
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x * y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
}
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
eval(inputs, out);
}
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int stride = in.shape(axis_);
int count = in.size() / stride;
const float* input = in.data<float>();
float* output = out.data<float>();
float s = 1.0;
if (!reverse_) {
for (int i = 0; i < count; i++) {
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
input += stride;
output += stride;
}
} else {
for (int i = 0; i < count; i++) {
input += stride - 1;
output += stride - 1;
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
input++;
output++;
}
}
} else {
eval(inputs, out);
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
eval(inputs, out);
}
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
vvsqrtf(out.data<float>(), in.data<float>(), &size);
}
} else {
eval(inputs, out);
}
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
a,
b,
out,
[](auto x, auto y) { return x - y; },
[](const auto* s, const auto* vec, auto* o, auto n) {
float minus_1 = -1;
vDSP_vsmsa(
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
},
[](const auto* vec, const auto* s, auto* o, auto n) {
float val = -(*s);
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
},
[](const auto* a, const auto* b, auto* o, auto n) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
a,
b,
out,
[](auto x, auto y) { return x - y; },
UseDefaultBinaryOp(),
[](const auto* vec, const auto* s, auto* o, auto n) {
int val = -(*s);
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
},
UseDefaultBinaryOp());
} else {
eval(inputs, out);
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,117 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <simd/vector.h>
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void _qmm_t_4_64(
float* result,
const float* x,
const uint32_t* w,
const float* scales,
const float* biases,
int M,
int N,
int K,
int B,
bool batched_w) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
}
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
bool condition =
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
scales.flags().row_contiguous && biases.flags().row_contiguous &&
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
B,
batched_w);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,139 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
};
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
}
}
};
} // namespace
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32) {
if (reduce_type_ == Reduce::Sum) {
reduction_op<float, float>(
in,
out,
axes_,
0,
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
(*accum) += acc;
},
[](auto* accum, auto x) { *accum += x; });
return;
} else if (reduce_type_ == Reduce::Max) {
reduction_op<float, float>(
in,
out,
axes_,
-std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
(*accum) = (*accum < max) ? max : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
return;
} else if (reduce_type_ == Reduce::Min) {
reduction_op<float, float>(
in,
out,
axes_,
std::numeric_limits<float>::infinity(),
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);
(*accum) = (*accum > min) ? min : *accum;
},
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
return;
}
}
// TODO: Add integer addition and min/max using the templates above and
// simd_int16 and friends.
eval(inputs, out);
}
} // namespace mlx::core

View File

@@ -0,0 +1,393 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
auto x = x_init * 1.442695; // multiply with log_2(e)
simd_float16 ipart, fpart;
simd_int16 epart;
x = simd_clamp(x, -80, 80);
ipart = simd::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart = (simd_int(ipart) + 127) << 23;
// Avoid supressing NaNs
simd_int16 eq = (x_init == x_init);
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
int16x8_t epart = vcvtq_s16_f16(ipart);
epart = vaddq_s16(epart, vdupq_n_s16(15));
epart = vshlq_n_s16(epart, 10);
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}
/**
* Implementation of folding maximum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_max(float16x8_t x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
/**
* Implementation of folding sum for ARM neon. This should possibly be
* refactored out of softmax.cpp at some point.
*/
inline float16_t neon_reduce_add(float16x8_t x) {
float16x4_t y;
float16x4_t zero = vdup_n_f16(0);
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
y = vpadd_f16(y, zero);
y = vpadd_f16(y, zero);
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
size_t s = M;
while (s >= N) {
VT vals;
if constexpr (std::is_same<T, AccT>::value) {
vals = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vals[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vmaximum = ops.max(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = ops.reduce_max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
VT vnormalizer = ops.init(0.0);
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
VT vexp;
if constexpr (std::is_same<T, AccT>::value) {
vexp = ops.load(current_in_ptr);
} else {
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
}
vexp = ops.exp(ops.sub(vexp, maximum));
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, vexp);
}
vnormalizer = ops.add(vnormalizer, vexp);
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = ops.reduce_add(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if (std::is_same<T, AccT>::value) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (std::is_same<T, AccT>::value) {
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
} else {
VT vexp;
for (int i = 0; i < N; ++i) {
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
}
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
for (int i = 0; i < N; ++i) {
current_out_ptr[i] = vexp[i];
}
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto check_input = [](array x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General);
return x_copy;
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
switch (in.dtype()) {
case bool_:
case uint8:
case uint16:
case uint32:
case uint64:
case int8:
case int16:
case int32:
case int64:
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
softmax<
float,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
break;
case float16:
if (precise_) {
softmax<
float16_t,
float,
simd_float16,
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:
eval(inputs, out);
break;
case complex64:
eval(inputs, out);
break;
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <Accelerate/Accelerate.h>
#include "mlx/dtype.h"
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
}
} // namespace mlx::core

View File

@@ -1,8 +1,71 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
if(MSVC)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
set(SHELL_EXT sh)
set(SHELL_CMD /bin/bash)
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.${SHELL_EXT}
compiled_preamble.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h)
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()

View File

@@ -62,9 +62,6 @@ void arange(
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;

View File

@@ -2,8 +2,8 @@
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "utils.h"
namespace mlx::core {
@@ -61,7 +61,7 @@ void arg_reduce_dispatch(
} // namespace
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -103,9 +103,6 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;

View File

@@ -5,9 +5,9 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/binary_two.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -15,64 +15,69 @@ namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, U, Op> opsv(op);
DefaultVectorScalar<T, U, Op> opvs(op);
DefaultVectorVector<T, U, Op> opvv(op);
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
}
template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool>(a, b, out, op);
comparison_op<bool, bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t, bool>(a, b, out, op);
comparison_op<uint8_t, bool>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t, bool>(a, b, out, op);
comparison_op<uint16_t, bool>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t, bool>(a, b, out, op);
comparison_op<uint32_t, bool>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t, bool>(a, b, out, op);
comparison_op<uint64_t, bool>(a, b, out, op);
break;
case int8:
binary_op<int8_t, bool>(a, b, out, op);
comparison_op<int8_t, bool>(a, b, out, op);
break;
case int16:
binary_op<int16_t, bool>(a, b, out, op);
comparison_op<int16_t, bool>(a, b, out, op);
break;
case int32:
binary_op<int32_t, bool>(a, b, out, op);
comparison_op<int32_t, bool>(a, b, out, op);
break;
case int64:
binary_op<int64_t, bool>(a, b, out, op);
comparison_op<int64_t, bool>(a, b, out, op);
break;
case float16:
binary_op<float16_t, bool>(a, b, out, op);
comparison_op<float16_t, bool>(a, b, out, op);
break;
case float32:
binary_op<float, bool>(a, b, out, op);
break;
case float64:
binary_op<double, bool>(a, b, out, op);
comparison_op<float, bool>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
comparison_op<bfloat16_t, bool>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, op);
comparison_op<complex64_t, bool>(a, b, out, op);
break;
}
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
void Add::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add());
}
void DivMod::eval_cpu(
void DivMod::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
@@ -117,9 +122,6 @@ void DivMod::eval_cpu(
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
@@ -130,141 +132,118 @@ void DivMod::eval_cpu(
}
}
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
void Divide::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
void Remainder::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
void Equal::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (equal_nan_) {
switch (a.dtype()) {
case float16:
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
break;
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
break;
case float64:
binary_op<double, bool>(a, b, out, detail::NaNEqual());
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
break;
case complex64:
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
break;
default:
throw std::runtime_error(
"[NanEqual::eval_cpu] Only for floating point types.");
}
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
} else {
comparison_op(a, b, out, detail::Equal());
comparison_op(inputs[0], inputs[1], out, detail::Equal());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
void Greater::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
void Less::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::LogAddExp());
break;
case float32:
binary_op<float>(a, b, out, detail::LogAddExp());
break;
case float64:
binary_op<double>(a, b, out, detail::LogAddExp());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::LogAddExp());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[logaddexp] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[logaddexp] Cannot compute logaddexp for arrays with"
" non floating point type.");
}
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
void Multiply::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
void Power::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
void Subtract::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
@@ -328,26 +307,24 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
}

View File

@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
@@ -8,6 +9,8 @@
namespace mlx::core {
namespace {
enum class BinaryOpType {
ScalarScalar,
ScalarVector,
@@ -16,7 +19,7 @@ enum class BinaryOpType {
General,
};
inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = BinaryOpType::ScalarScalar;
@@ -34,7 +37,7 @@ inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
return bopt;
}
inline void set_binary_op_output_data(
void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
@@ -119,4 +122,409 @@ inline void set_binary_op_output_data(
}
}
struct UseDefaultBinaryOp {};
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
Op op;
DefaultVectorScalar(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
while (size-- > 0) {
*dst = op(*a, scalar);
dst++;
a++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
Op op;
DefaultScalarVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
while (size-- > 0) {
*dst = op(scalar, *b);
dst++;
b++;
}
}
};
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
Op op;
DefaultVectorVector(Op op_) : op(op_) {}
void operator()(const T* a, const T* b, U* dst, int size) {
while (size-- > 0) {
*dst = op(*a, *b);
dst++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
} else {
*out = op(*a, *b);
}
}
out += stride_out;
a += stride_a;
b += stride_b;
}
}
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
}
}
template <
typename T,
typename U,
typename Op,
typename OpSV,
typename OpVS,
typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break;
}
}
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
const array& a,
const array& b,
array& out,
Op op,
OpSV opsv,
OpVS opvs,
OpVV opvv) {
// TODO: The following mess of constexpr evaluations can probably be achieved
// with template specializations and overloading. Would it be simpler?
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv and opvs were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
DefaultVectorScalar<T, T, Op>(op),
opvv);
}
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opsv and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
DefaultScalarVector<T, T, Op>(op),
opvs,
DefaultVectorVector<T, T, Op>(op));
} else {
// opsv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
}
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
value) {
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
// opvs and opvv were UseDefaultBinaryOp
binary_op<T, T>(
a,
b,
out,
op,
opsv,
DefaultVectorScalar<T, T, Op>(op),
DefaultVectorVector<T, T, Op>(op));
} else {
// opvs was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
}
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
value) {
// opvv was UseDefaultBinaryOp
binary_op<T, T>(
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
} else {
// All ops provided
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
DefaultScalarVector<T, T, Op> opsv(op);
DefaultVectorScalar<T, T, Op> opvs(op);
DefaultVectorVector<T, T, Op> opvv(op);
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, ops...);
break;
case uint8:
binary_op<uint8_t>(a, b, out, ops...);
break;
case uint16:
binary_op<uint16_t>(a, b, out, ops...);
break;
case uint32:
binary_op<uint32_t>(a, b, out, ops...);
break;
case uint64:
binary_op<uint64_t>(a, b, out, ops...);
break;
case int8:
binary_op<int8_t>(a, b, out, ops...);
break;
case int16:
binary_op<int16_t>(a, b, out, ops...);
break;
case int32:
binary_op<int32_t>(a, b, out, ops...);
break;
case int64:
binary_op<int64_t>(a, b, out, ops...);
break;
case float16:
binary_op<float16_t>(a, b, out, ops...);
break;
case float32:
binary_op<float>(a, b, out, ops...);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, ops...);
break;
case complex64:
binary_op<complex64_t>(a, b, out, ops...);
break;
}
}
} // namespace
} // namespace mlx::core

View File

@@ -2,8 +2,8 @@
#pragma once
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"
namespace mlx::core {
@@ -205,9 +205,6 @@ void binary(
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -64,7 +64,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
}
}
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}

View File

@@ -151,9 +151,6 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case float64:
*out.data<double>() = static_cast<double>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;

View File

@@ -2,16 +2,15 @@
#include <dlfcn.h>
#include <filesystem>
#include <format>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
@@ -111,7 +110,7 @@ void* compile(
JitCompiler::exec(JitCompiler::build_command(
output_dir, source_file_name, shared_lib_name));
} catch (const std::exception& error) {
throw std::runtime_error(fmt::format(
throw std::runtime_error(std::format(
"[Compile::eval_cpu] Failed to compile function {0}: {1}",
kernel_name,
error.what()));

View File

@@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/compile_impl.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/compiled.h"
namespace mlx::core {

View File

@@ -5,8 +5,7 @@
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/common/ops.h"
// clang-format on
const char* get_kernel_preamble();

View File

@@ -3,8 +3,8 @@
#include <cassert>
#include <numeric>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -1128,7 +1128,7 @@ void conv_3D_cpu(
} // namespace
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
void Convolution::eval(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& in = inputs[0];

View File

@@ -3,9 +3,8 @@
#include <numeric>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -24,7 +23,6 @@ template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
auto src_ptr = src.data<SrcT>();
auto dst_ptr = dst.data<DstT>();
size_t size = src.data_size();
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
@@ -193,9 +191,6 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
@@ -245,9 +240,6 @@ inline void copy_inplace_dispatch(
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -22,4 +23,17 @@ enum class CopyType {
GeneralGeneral
};
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -0,0 +1,198 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#define DEFAULT(primitive) \
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
primitive::eval(inputs, out); \
}
#define DEFAULT_MULTI(primitive) \
void primitive::eval_cpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
primitive::eval(inputs, outputs); \
}
namespace mlx::core {
DEFAULT(Abs)
DEFAULT(Add)
DEFAULT(Arange)
DEFAULT(ArcCos)
DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(Multiply)
DEFAULT(Negative)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {
inline void matmul_common_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_common_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

View File

@@ -2,8 +2,8 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
@@ -45,9 +45,7 @@ void ssyevd(
} // namespace
void Eigh::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];

View File

@@ -0,0 +1,40 @@
// Copyright © 2023 Apple Inc.
#include <cmath>
namespace mlx::core {
/* Approximation to the inverse error function.
* Based on code from:
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
*/
float erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
} // namespace mlx::core

View File

@@ -8,7 +8,7 @@
namespace mlx::core {
void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
void FFT::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
std::vector<std::ptrdiff_t> strides_in(
in.strides().begin(), in.strides().end());

View File

@@ -2,8 +2,8 @@
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -82,7 +82,7 @@ void hadamard(array& out, int n, int m, float scale) {
}
}
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -104,4 +104,4 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -6,8 +6,8 @@
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
namespace mlx::core {
@@ -16,6 +16,11 @@ inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
@@ -148,9 +153,6 @@ void dispatch_gather(
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case float64:
gather<double, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
@@ -160,18 +162,21 @@ void dispatch_gather(
}
}
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
void Gather::eval(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
std::vector<array> inds(inputs.begin() + 1, inputs.end());
if (inds.empty()) {
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
return;
}
switch (inds[0].dtype()) {
case bool_:
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
break;
case uint8:
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
break;
@@ -196,145 +201,12 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
case int64:
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
break;
default:
throw std::runtime_error(
"[Gather::eval_cpu] Cannot gather with indices type.");
break;
}
}
template <typename T, typename IdxT>
void gather_axis(
const array& src,
const array& ind,
array& out,
const int axis) {
auto strides = ind.strides();
strides.erase(strides.begin() + axis);
auto shape = ind.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>();
auto dst_ptr = out.data<T>();
auto ind_ax_stride = ind.strides(axis);
auto src_ax_stride = src.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto ind_ax_size = ind.shape(axis);
auto src_ax_size = src.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= ind.shape(i);
}
for (int i = axis + 1; i < ind.ndim(); ++i) {
size_post *= ind.shape(i);
}
size_t stride_pre = size_post * ind_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < ind_ax_size; ++j) {
auto ind_val = offset_neg_idx(
ind_ptr[ind_it.loc + j * ind_ax_stride], src_ax_size);
dst_ptr[k + j * dst_ax_stride] =
src_ptr[src_it.loc + ind_val * src_ax_stride];
}
ind_it.step();
src_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename IdxT>
void dispatch_gather_axis(
const array& src,
const array& inds,
array& out,
const int axis) {
switch (out.dtype()) {
case bool_:
gather_axis<bool, IdxT>(src, inds, out, axis);
break;
case uint8:
gather_axis<uint8_t, IdxT>(src, inds, out, axis);
break;
case uint16:
gather_axis<uint16_t, IdxT>(src, inds, out, axis);
break;
case uint32:
gather_axis<uint32_t, IdxT>(src, inds, out, axis);
break;
case uint64:
gather_axis<uint64_t, IdxT>(src, inds, out, axis);
break;
case int8:
gather_axis<int8_t, IdxT>(src, inds, out, axis);
break;
case int16:
gather_axis<int16_t, IdxT>(src, inds, out, axis);
break;
case int32:
gather_axis<int32_t, IdxT>(src, inds, out, axis);
break;
case int64:
gather_axis<int64_t, IdxT>(src, inds, out, axis);
break;
case float16:
gather_axis<float16_t, IdxT>(src, inds, out, axis);
break;
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case float64:
gather_axis<double, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
case complex64:
gather_axis<complex64_t, IdxT>(src, inds, out, axis);
break;
}
}
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& src = inputs[0];
auto& inds = inputs[1];
switch (inds.dtype()) {
case uint8:
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
break;
case uint16:
dispatch_gather_axis<uint16_t>(src, inds, out, axis_);
break;
case uint32:
dispatch_gather_axis<uint32_t>(src, inds, out, axis_);
break;
case uint64:
dispatch_gather_axis<uint64_t>(src, inds, out, axis_);
break;
case int8:
dispatch_gather_axis<int8_t>(src, inds, out, axis_);
break;
case int16:
dispatch_gather_axis<int16_t>(src, inds, out, axis_);
break;
case int32:
dispatch_gather_axis<int32_t>(src, inds, out, axis_);
break;
case int64:
dispatch_gather_axis<int64_t>(src, inds, out, axis_);
break;
default:
throw std::runtime_error(
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
"[Gather::eval] Cannot gather with floating point indices.");
break;
}
}
@@ -424,11 +296,14 @@ void dispatch_scatter(
const std::vector<int>& axes,
Scatter::ReduceType rtype) {
if (inds.empty()) {
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
return;
}
switch (inds[0].dtype()) {
case bool_:
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
break;
case uint8:
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
break;
@@ -453,13 +328,16 @@ void dispatch_scatter(
case int64:
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
break;
default:
case float16:
case float32:
case bfloat16:
case complex64:
throw std::runtime_error(
"[Scatter::eval_cpu] Cannot scatter with indices type.");
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
}
}
void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
void Scatter::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
@@ -467,9 +345,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& updates = inputs.back();
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
copy(src, out, CopyType::General);
switch (src.dtype()) {
case bool_:
@@ -505,9 +381,6 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
@@ -517,170 +390,4 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
template <typename T, typename IdxT, typename OpT>
void scatter_axis(
array& out,
const array idx,
const array& upd,
int axis,
const OpT& op) {
auto strides = idx.strides();
strides.erase(strides.begin() + axis);
auto shape = idx.shape();
shape.erase(shape.begin() + axis);
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>();
auto dst_ptr = out.data<T>();
auto idx_ax_stride = idx.strides(axis);
auto upd_ax_stride = upd.strides(axis);
auto dst_ax_stride = out.strides(axis);
auto idx_ax_size = idx.shape(axis);
auto dst_ax_size = out.shape(axis);
size_t size_pre = 1;
size_t size_post = 1;
for (int i = 0; i < axis; ++i) {
size_pre *= idx.shape(i);
}
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
size_t stride_pre = size_post * dst_ax_size;
for (size_t i = 0; i < size_pre; i++) {
for (size_t k = 0; k < size_post; k++) {
for (int j = 0; j < idx_ax_size; ++j) {
auto ind_val = offset_neg_idx(
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
dst_ptr + k + ind_val * dst_ax_stride);
}
idx_it.step();
upd_it.step();
}
dst_ptr += stride_pre;
}
}
template <typename InT, typename IdxT>
void dispatch_scatter_axis_op(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (rtype) {
case ScatterAxis::None:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
break;
case ScatterAxis::Sum:
scatter_axis<InT, IdxT>(
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
break;
}
}
template <typename InT>
void dispatch_scatter_axis(
array& out,
const array& idx,
const array& updates,
int axis,
ScatterAxis::ReduceType rtype) {
switch (idx.dtype()) {
case uint8:
dispatch_scatter_axis_op<InT, uint8_t>(out, idx, updates, axis, rtype);
break;
case uint16:
dispatch_scatter_axis_op<InT, uint16_t>(out, idx, updates, axis, rtype);
break;
case uint32:
dispatch_scatter_axis_op<InT, uint32_t>(out, idx, updates, axis, rtype);
break;
case uint64:
dispatch_scatter_axis_op<InT, uint64_t>(out, idx, updates, axis, rtype);
break;
case int8:
dispatch_scatter_axis_op<InT, int8_t>(out, idx, updates, axis, rtype);
break;
case int16:
dispatch_scatter_axis_op<InT, int16_t>(out, idx, updates, axis, rtype);
break;
case int32:
dispatch_scatter_axis_op<InT, int32_t>(out, idx, updates, axis, rtype);
break;
case int64:
dispatch_scatter_axis_op<InT, int64_t>(out, idx, updates, axis, rtype);
break;
default:
throw std::runtime_error(
"[ScatterAxis::eval_cpu] Cannot scatter with indices type.");
}
}
void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() >= 2);
auto& src = inputs[0];
auto& idx = inputs[1];
auto& updates = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy(src, out, ctype);
switch (src.dtype()) {
case bool_:
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
break;
case uint8:
dispatch_scatter_axis<uint8_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint16:
dispatch_scatter_axis<uint16_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint32:
dispatch_scatter_axis<uint32_t>(out, idx, updates, axis_, reduce_type_);
break;
case uint64:
dispatch_scatter_axis<uint64_t>(out, idx, updates, axis_, reduce_type_);
break;
case int8:
dispatch_scatter_axis<int8_t>(out, idx, updates, axis_, reduce_type_);
break;
case int16:
dispatch_scatter_axis<int16_t>(out, idx, updates, axis_, reduce_type_);
break;
case int32:
dispatch_scatter_axis<int32_t>(out, idx, updates, axis_, reduce_type_);
break;
case int64:
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
break;
case float16:
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
break;
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;
case complex64:
dispatch_scatter_axis<complex64_t>(
out, idx, updates, axis_, reduce_type_);
break;
}
}
} // namespace mlx::core

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
@@ -110,7 +110,7 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
}
}
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}

View File

@@ -1,11 +1,12 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/backend/common/jit_compiler.h"
#include <algorithm>
#include <sstream>
#include <vector>
#include <fmt/format.h>
#include <format>
namespace mlx::core {
@@ -33,7 +34,7 @@ struct VisualStudioInfo {
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = JitCompiler::exec(fmt::format(
std::string vs_path = JitCompiler::exec(std::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
@@ -41,7 +42,7 @@ struct VisualStudioInfo {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = JitCompiler::exec(fmt::format(
std::string envs = JitCompiler::exec(std::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
@@ -54,8 +55,8 @@ struct VisualStudioInfo {
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir" || name == "VCTOOLSINSTALLDIR") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
} else if (name == "VCToolsInstallDir") {
cl_exe = std::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
@@ -81,9 +82,9 @@ std::string JitCompiler::build_command(
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
libpaths += std::format(" /libpath:\"{0}\"", lib);
}
return fmt::format(
return std::format(
"\""
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
@@ -95,8 +96,8 @@ std::string JitCompiler::build_command(
shared_lib_name,
libpaths);
#else
return fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared \"{0}\" -o \"{1}\" 2>&1",
return std::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}' 2>&1",
(dir / source_file_name).string(),
(dir / shared_lib_name).string());
#endif
@@ -133,13 +134,13 @@ std::string JitCompiler::exec(const std::string& cmd) {
if (status == -1) {
throw std::runtime_error("pclose() failed.");
}
#if defined(_WIN32) || defined(__FreeBSD__)
#ifdef _MSC_VER
int code = status;
#else
int code = WEXITSTATUS(status);
#endif
if (code != 0) {
throw std::runtime_error(fmt::format(
throw std::runtime_error(std::format(
"Failed to execute command with return code {0}: \"{1}\", "
"the output is: {2}",
code,

View File

@@ -11,7 +11,7 @@
#define lapack_complex_double std::complex<double>
#endif
#ifdef MLX_USE_ACCELERATE
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>

View File

@@ -1,9 +1,12 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <utility>
#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
namespace {
@@ -48,4 +51,11 @@ void load(
}
}
void Load::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
} // namespace mlx::core

View File

@@ -8,7 +8,7 @@ $CL = $args[1]
$SRCDIR = $args[2]
# Get command result as array.
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h"
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
# Remove empty lines.
# Otherwise there will be too much empty lines making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }

View File

@@ -24,7 +24,7 @@ else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -3,9 +3,9 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -53,7 +53,7 @@ inline void mask_matrix(
} // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32.");
@@ -210,7 +210,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32.");

680
mlx/backend/common/ops.h Normal file
View File

@@ -0,0 +1,680 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
namespace mlx::core::detail {
namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace
typedef union {
int i;
float f;
} IntOrFloat;
inline float fast_exp(float x) {
if (x == -std::numeric_limits<float>::infinity()) {
return 0.0f;
} else if (x == std::numeric_limits<float>::infinity() || std::isnan(x)) {
return x;
}
x *= 1.442695; // multiply with log_2(e)
float ipart, fpart;
IntOrFloat epart;
x = std::max(-80.f, std::min(x, 80.f));
ipart = std::floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = x * fpart + 1.339887440266574e-3f;
x = x * fpart + 9.618437357674640e-3f;
x = x * fpart + 5.550332471162809e-2f;
x = x * fpart + 2.402264791363012e-1f;
x = x * fpart + 6.931472028550421e-1f;
x = x * fpart + 1.000000000000000f;
// generate 2**ipart in the floating point representation using integer
// bitshifting
epart.i = (int(ipart) + 127) << 23;
return epart.f * x;
}
inline float fast_erf(float a) {
float r, s, t, u;
t = std::abs(a);
s = a * a;
if (t > 0.927734375f) {
// maximum error 0.99527 ulp
r = std::fma(
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
u = std::fma(
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
r = std::fma(r, s, u);
r = std::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
r = std::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
r = std::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
r = std::fma(r, t, -t);
// TODO, replace with expm1 when implemented
r = 1.0f - std::exp(r);
r = std::copysign(r, a);
} else {
// maximum error 0.98929 ulp
r = -5.96761703e-4f; // -0x1.38e000p-11
r = std::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
r = std::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
r = std::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
r = std::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
r = std::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
r = std::fma(r, a, a);
}
return r;
}
inline float fast_erfinv(float a) {
auto t = std::fma(a, 0.0f - a, 1.0f);
t = std::log(t);
float p;
if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
} else { // maximum ulp error = 2.35002
p = 5.43877832e-9f; // 0x1.75c000p-28
p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
struct Abs {
template <typename T>
T operator()(T x) {
return std::abs(x);
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return std::acos(x);
}
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return std::acosh(x);
}
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return std::asin(x);
}
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return std::asinh(x);
}
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return std::atan(x);
}
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
}
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return std::atanh(x);
}
};
struct Ceil {
template <typename T>
T operator()(T x) {
return std::ceil(x);
}
int8_t operator()(int8_t x) {
return x;
}
int16_t operator()(int16_t x) {
return x;
}
int32_t operator()(int32_t x) {
return x;
}
int64_t operator()(int64_t x) {
return x;
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return std::conj(x);
}
};
struct Cos {
template <typename T>
T operator()(T x) {
return std::cos(x);
}
};
struct Cosh {
template <typename T>
T operator()(T x) {
return std::cosh(x);
}
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x)));
}
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
}
};
struct Exp {
template <typename T>
T operator()(T x) {
return fast_exp(x);
}
complex64_t operator()(complex64_t x) {
return std::exp(x);
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
}
};
struct Floor {
template <typename T>
T operator()(T x) {
return std::floor(x);
}
int8_t operator()(int8_t x) {
return x;
}
int16_t operator()(int16_t x) {
return x;
}
int32_t operator()(int32_t x) {
return x;
}
int64_t operator()(int64_t x) {
return x;
}
uint8_t operator()(uint8_t x) {
return x;
}
uint16_t operator()(uint16_t x) {
return x;
}
uint32_t operator()(uint32_t x) {
return x;
}
uint64_t operator()(uint64_t x) {
return x;
}
bool operator()(bool x) {
return x;
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
return std::log(x);
}
};
struct Log2 {
template <typename T>
T operator()(T x) {
return std::log2(x);
}
};
struct Log10 {
template <typename T>
T operator()(T x) {
return std::log10(x);
}
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
}
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
}
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {
return std::rint(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
}
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto one = static_cast<decltype(x)>(1.0);
return one / (one + fast_exp(-x));
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
}
uint8_t operator()(uint8_t x) {
return x != 0;
}
uint16_t operator()(uint16_t x) {
return x != 0;
}
uint32_t operator()(uint32_t x) {
return x != 0;
}
uint64_t operator()(uint64_t x) {
return x != 0;
}
complex64_t operator()(complex64_t x) {
return x == complex64_t(0) ? x : x / std::abs(x);
}
};
struct Sin {
template <typename T>
T operator()(T x) {
return std::sin(x);
}
};
struct Sinh {
template <typename T>
T operator()(T x) {
return std::sinh(x);
}
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
}
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return std::sqrt(x);
}
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
}
};
struct Tan {
template <typename T>
T operator()(T x) {
return std::tan(x);
}
};
struct Tanh {
template <typename T>
T operator()(T x) {
return std::tanh(x);
}
};
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
return numerator % denominator;
}
template <typename T>
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = numerator % denominator;
if (r != 0 && (r < 0 != denominator < 0))
r += denominator;
return r;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
T numerator,
T denominator) {
auto r = std::fmod(numerator, denominator);
if (r != 0 && (r < 0 != denominator < 0)) {
r += denominator;
}
return r;
}
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
return numerator % denominator;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
if constexpr (std::is_integral_v<T>) {
// isnan always returns false for integers, and MSVC refuses to compile.
return x == y;
} else {
return x == y || (std::isnan(x) && std::isnan(y));
}
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct Maximum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return (x > y) ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
}
};
struct Minimum {
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T x, T y) {
return x < y ? x : y;
}
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T x, T y) {
if (std::isnan(x)) {
return x;
}
return x < y ? x : y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
constexpr float inf = std::numeric_limits<float>::infinity();
auto maxval = Maximum()(x, y);
auto minval = Minimum()(x, y);
return (minval == -inf || maxval == inf)
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval)));
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
};
struct Power {
template <typename T>
std::enable_if_t<!std::is_integral_v<T>, T> operator()(T base, T exp) {
return std::pow(base, exp);
}
template <typename T>
std::enable_if_t<std::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
}
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
}
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
}
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
}
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
}
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
}
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
}
};
} // namespace mlx::core::detail

View File

@@ -0,0 +1,714 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
int64_t compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
default:
throw std::runtime_error("Invalid indices type.");
}
}
void Abs::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), unsignedinteger)) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Abs());
}
}
void Arange::eval(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
}
void ArcCos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCos());
} else {
throw std::invalid_argument(
"[arccos] Cannot compute inverse cosine of elements in array"
" with non floating point type.");
}
}
void ArcCosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcCosh());
} else {
throw std::invalid_argument(
"[arccosh] Cannot compute inverse hyperbolic cosine of elements in"
" array with non floating point type.");
}
}
void ArcSin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSin());
} else {
throw std::invalid_argument(
"[arcsin] Cannot compute inverse sine of elements in array"
" with non floating point type.");
}
}
void ArcSinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcSinh());
} else {
throw std::invalid_argument(
"[arcsinh] Cannot compute inverse hyperbolic sine of elements in"
" array with non floating point type.");
}
}
void ArcTan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTan());
} else {
throw std::invalid_argument(
"[arctan] Cannot compute inverse tangent of elements in array"
" with non floating point type.");
}
}
void ArcTanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::ArcTanh());
} else {
throw std::invalid_argument(
"[arctanh] Cannot compute inverse hyperbolic tangent of elements in"
" array with non floating point type.");
}
}
void AsType::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Concatenate::eval(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
}
}
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_fp(in, out, detail::Conjugate());
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cos());
} else {
throw std::invalid_argument(
"[cos] Cannot compute cosine of elements in array"
" with non floating point type.");
}
}
void Cosh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Cosh());
} else {
throw std::invalid_argument(
"[cosh] Cannot compute hyperbolic cosine of elements in array"
" with non floating point type.");
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Exp());
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Expm1::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Expm1());
} else {
throw std::invalid_argument(
"[expm1] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Floor::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Full::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
break;
}
} else {
throw std::invalid_argument(
"[log] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void Log1p::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Log1p());
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
}
void Pad::eval(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
}
void RandomBits::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Round::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sigmoid());
} else {
throw std::invalid_argument(
"[sigmoid] Cannot sigmoid of elements in array with"
" non floating point type.");
}
}
void Sign::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
}
}
void Sin::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sin());
} else {
throw std::invalid_argument(
"[sin] Cannot compute sine of elements in array"
" with non floating point type.");
}
}
void Sinh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Sinh());
} else {
throw std::invalid_argument(
"[sinh] Cannot compute hyperbolic sine of elements in array"
" with non floating point type.");
}
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
// Copy or move src to dst
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
}
void Sqrt::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tan());
} else {
throw std::invalid_argument(
"[tan] Cannot compute tangent of elements in array"
" with non floating point type.");
}
}
void Tanh::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, detail::Tanh());
} else {
throw std::invalid_argument(
"[tanh] Cannot compute hyperbolic tangent of elements in array"
" with non floating point type.");
}
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -41,7 +41,7 @@ template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
const int lda = std::max(M, N);
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
@@ -89,16 +89,13 @@ void qrf_impl(const array& a, array& q, array& r) {
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
copy_inplace(in, r, CopyType::General);
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
// Zero lower triangle
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] = 0;
}
for (int k = j; k < r.shape(-1); ++k) {
r.data<T>()[i * N * num_reflectors + j * N + k] =
in.data<T>()[i * N * M + j + k * M];
r.data<T>()[i * N * M + j * N + k] = 0;
}
}
}
@@ -107,7 +104,7 @@ void qrf_impl(const array& a, array& q, array& r) {
lwork = -1;
lpack<T>::xorgqr(
&M,
&num_reflectors,
&N,
&num_reflectors,
nullptr,
&lda,
@@ -123,7 +120,7 @@ void qrf_impl(const array& a, array& q, array& r) {
// Compute Q
lpack<T>::xorgqr(
&M,
&num_reflectors,
&N,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
@@ -134,24 +131,14 @@ void qrf_impl(const array& a, array& q, array& r) {
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < q.shape(-2); ++j) {
for (int k = 0; k < q.shape(-1); ++k) {
q.data<T>()[i * M * num_reflectors + j * num_reflectors + k] =
in.data<T>()[i * N * M + j + k * M];
}
}
}
copy_inplace(in, q, CopyType::General);
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}

View File

@@ -2,8 +2,8 @@
#include <cassert>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -151,78 +151,6 @@ void _qmm_t(
}
}
template <int bits, int S>
simd::Simd<uint32_t, S> extract_bits_simd(const uint32_t* w) {
constexpr int bitmask = (1 << bits) - 1;
simd::Simd<uint32_t, S> wi;
if constexpr (bits == 4 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
wi = simd::Simd<uint32_t, S>(*w);
wi = wi >> shifts;
wi = wi & bitmask;
} else if constexpr (bits == 8 && S == 8) {
constexpr std::array<uint32_t, 8> shifts_ = {{0, 8, 16, 24, 0, 8, 16, 24}};
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
auto l = simd::Simd<uint32_t, 4>(*w++);
auto r = simd::Simd<uint32_t, 4>(*w);
wi = simd::Simd<uint32_t, S>(l, r);
wi = wi >> shifts;
wi = wi & bitmask;
} else {
// Appease compiler.. but should never get here
throw std::runtime_error("Unsupported combination for simd qmm.");
}
return wi;
}
template <typename T, int bits, int group_size>
void _qmm_t_simd(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
constexpr int S = simd::max_size<T>;
static_assert(
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
constexpr int packs_per_simd = S / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
simd::Simd<float, S> acc(0);
auto x_local = x;
for (int k = 0; k < K; k += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
auto wf = simd::Simd<float, S>(extract_bits_simd<bits, S>(w_local));
w_local += packs_per_simd;
wf = wf * scale;
wf = wf + bias;
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
acc = acc + x_simd * wf;
x_local += S;
}
}
*result = T(simd::sum(acc));
result++;
}
x += K;
}
}
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
@@ -235,14 +163,9 @@ void _qmm_dispatch_transpose(
int K,
bool transposed_w) {
if (transposed_w) {
// the simd size must be a multiple of the number of elements per word
if constexpr (32 % bits == 0 && simd::max_size<T> % (32 / bits) == 0) {
_qmm_t_simd<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
_qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
_qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
@@ -326,13 +249,13 @@ void _qmm_dispatch(
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
@@ -461,7 +384,7 @@ void _bs_qmm_dispatch(
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x_pre = inputs[0];
@@ -488,7 +411,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];

View File

@@ -1,147 +1,312 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <functional>
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
namespace {
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
template <typename U>
struct Limits {
static const U max;
static const U min;
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr type max = std::numeric_limits<type>::max(); \
static constexpr type min = std::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static const type max; \
static const type min; \
};
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(complex64_t);
template <>
struct Limits<bool> {
static constexpr bool max = true;
static constexpr bool min = false;
};
const float Limits<float>::max = std::numeric_limits<float>::infinity();
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::max =
std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::min =
-std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
-std::numeric_limits<float>::infinity();
struct AndReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) &= (b != 0);
}
return std::make_pair(shape, strides);
void operator()(bool* y, bool x) {
(*y) &= x;
}
};
struct OrReduce {
template <typename T>
void operator()(bool* a, T b) {
(*a) |= (b != 0);
}
void operator()(bool* y, bool x) {
(*y) |= x;
}
};
struct MaxReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y > x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y > x) ? *y : x;
}
};
};
struct MinReduce {
template <typename T>
std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
(*y) = (*y < x) ? *y : x;
};
template <typename T>
std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
if (std::isnan(x)) {
*y = x;
} else {
(*y) = (*y < x) ? *y : x;
}
};
};
template <typename InT>
void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
}
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
}
} else {
auto op = [](auto y, auto x) { (*y) *= x; };
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, op);
} else {
reduction_op<InT, InT>(in, out, axes, 1, op);
}
}
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
} // namespace
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 2. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// b.stride = a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
Shape shape;
Strides strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int64_t size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
break;
}
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
} // namespace mlx::core

View File

@@ -48,8 +48,186 @@ struct ReductionPlan {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Helper for the ndimensional strided loop
// Should this be in utils?
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
Op op;
DefaultStridedReduce(Op op_) : op(op_) {}
void operator()(const T* x, U* accumulator, int size, size_t stride) {
for (int i = 0; i < size; i++) {
U* moving_accumulator = accumulator;
for (int j = 0; j < stride; j++) {
op(moving_accumulator, *x);
moving_accumulator++;
x++;
}
}
}
};
template <typename T, typename U, typename Op>
struct DefaultContiguousReduce {
Op op;
DefaultContiguousReduce(Op op_) : op(op_) {}
void operator()(const T* x, U* accumulator, int size) {
while (size-- > 0) {
op(accumulator, *x);
x++;
}
}
};
template <typename T, typename U, typename OpS, typename OpC, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
OpS ops,
OpC opc,
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
opc(x.data<T>(), out_ptr, x.size());
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init;
opc(x_ptr, out_ptr, reduction_size);
}
return;
}
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
int reduction_size = plan.shape.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
opc(x_ptr + offset, out_ptr, reduction_size);
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
opc(x_ptr + offset + extra_offset, out_ptr, reduction_size);
},
plan.shape,
plan.strides);
}
}
return;
}
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
ops(x_ptr, out_ptr, reduction_size, reduction_stride);
x_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
return;
}
if (plan.type == GeneralStridedReduce ||
plan.type == ContiguousStridedReduce) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride);
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
ops(x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride);
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
return;
}
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); },
plan.shape,
plan.strides);
*out_ptr = val;
}
}
}
template <typename T, typename U, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Op op) {
DefaultStridedReduce<T, U, Op> ops(op);
DefaultContiguousReduce<T, U, Op> opc(op);
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
}
} // namespace mlx::core

View File

@@ -0,0 +1,147 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/reduce.h"
namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
// Remove singleton axes from the plan
for (int i = shape.size() - 1; i >= 0; i--) {
if (shape[i] == 1) {
shape.erase(shape.begin() + i);
strides.erase(strides.begin() + i);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 2. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
bool a_is_zero = a.second == 0;
bool b_is_zero = b.second == 0;
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// b.stride = a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
Shape shape;
Strides strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int64_t size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;
}
have_expand = true;
break;
}
if (stride_i != size && shape_i != 1) {
break;
}
size *= shape_i;
}
// In the case of an expanded dimension we are being conservative and
// require the smallest reduction stride to be smaller than the maximum row
// contiguous size. The reason is that we can't easily know if the reduced
// axis is before or after an expanded dimension.
if (size > strides.back() || (size == strides.back() && !have_expand)) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
} // namespace mlx::core

325
mlx/backend/common/scan.cpp Normal file
View File

@@ -0,0 +1,325 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
struct DefaultContiguousScan {
Op op;
U init;
DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {}
void operator()(
const T* input,
U* output,
int count,
int stride,
bool reverse,
bool inclusive) {
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
*output = *input;
for (int j = 1; j < stride; j++) {
input++;
output++;
op(output, output - 1, input);
}
output++;
input++;
}
} else {
for (int i = 0; i < count; i++) {
*output = init;
for (int j = 1; j < stride; j++) {
op(output + 1, output, input);
input++;
output++;
}
output++;
input++;
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = *input;
for (int j = 1; j < stride; j++) {
input--;
output--;
op(output, output + 1, input);
}
output += stride;
input += stride;
}
} else {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = init;
for (int j = 1; j < stride; j++) {
op(output - 1, output, input);
input--;
output--;
}
output += stride;
input += stride;
}
}
}
}
};
template <typename T, typename U, typename Op>
struct DefaultStridedScan {
Op op;
U init;
DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {}
void operator()(
const T* input,
U* output,
int count,
int size,
int stride,
bool reverse,
bool inclusive) {
// TODO: Vectorize the following naive implementation
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
std::copy(input, input + stride, output);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
op(output, output - stride, input);
output++;
input++;
}
}
}
} else {
for (int i = 0; i < count; i++) {
std::fill(output, output + stride, init);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
op(output, output - stride, input - stride);
output++;
input++;
}
}
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::copy(input, input + stride, output);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
op(output, output + stride, input);
}
}
output += size * stride;
input += size * stride;
}
} else {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::fill(output, output + stride, init);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
op(output, output + stride, input + stride);
}
}
output += size * stride;
input += size * stride;
}
}
}
}
};
template <typename T, typename U, typename OpCS, typename OpSS>
void scan_op(
OpCS opcs,
OpSS opss,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
opcs(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
reverse,
inclusive);
} else {
opss(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
reverse,
inclusive);
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
}
}
template <typename T, typename U>
void scan_dispatch(
Scan::ReduceType rtype,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; };
auto init = static_cast<U>(0);
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
case Scan::Prod: {
auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); };
auto init = static_cast<U>(1);
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
case Scan::Min: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
case Scan::Max: {
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
break;
}
}
}
} // namespace
void Scan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General);
in = arr_copy;
}
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
}
} // namespace mlx::core

View File

@@ -2,8 +2,7 @@
#include <cassert>
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/ternary.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -51,9 +50,6 @@ void select_op(
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
@@ -65,7 +61,7 @@ void select_op(
} // namespace
void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
void Select::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
const auto& condition = inputs[0];
const auto& a = inputs[1];

View File

@@ -35,29 +35,4 @@ void shared_buffer_slice(
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
}
void slice(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
}
// data_end can be -1
size_t data_size =
data_end < 0 ? (data_offset - data_end) : (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
} // namespace mlx::core

View File

@@ -11,10 +11,11 @@ std::tuple<int64_t, Strides> prepare_slice(
const Shape& start_indices,
const Shape& strides);
void slice(
void shared_buffer_slice(
const array& in,
array& out,
const Shape& start_indices,
const Shape& strides);
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out);
} // namespace mlx::core

View File

@@ -3,109 +3,62 @@
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void softmax(const array& in, array& out) {
constexpr bool same_t = std::is_same_v<T, AccT>;
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
int N = in.shape().back();
int M = in.data_size() / N;
const T* current_in_ptr;
T* current_out_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
AccT maximum = *current_in_ptr;
for (int j = 0; j < N; j++, current_in_ptr++) {
maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
: maximum;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
AccT normalizer = 0;
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
if constexpr (same_t) {
store(current_out_ptr, vexp);
for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
AccT expv = std::exp(*current_in_ptr - maximum);
normalizer += expv;
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr = expv;
}
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
current_out_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
if constexpr (same_t) {
*current_out_ptr = _exp;
}
normalizer += _exp;
current_in_ptr++;
current_out_ptr++;
}
normalizer = 1 / normalizer;
// Normalize
current_out_ptr = out_ptr;
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
if constexpr (same_t) {
store(
current_out_ptr,
Simd<T, N>(load<T, N>(current_out_ptr) * normalizer));
} else {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum) * normalizer;
store(current_out_ptr, Simd<T, N>(vexp));
current_in_ptr += N;
}
current_out_ptr += N;
s -= N;
}
while (s-- > 0) {
if constexpr (same_t) {
current_out_ptr = out_ptr;
for (int j = 0; j < N; j++, current_out_ptr++) {
if constexpr (std::is_same<T, AccT>::value) {
*current_out_ptr *= normalizer;
} else {
AccT _exp = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(_exp * normalizer);
auto v = std::exp(*current_in_ptr - maximum);
*current_out_ptr = static_cast<T>(v * normalizer);
current_in_ptr++;
}
current_out_ptr++;
}
}
}
} // namespace
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
void Softmax::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
@@ -144,7 +97,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
case int16:
case int32:
case int64:
throw std::runtime_error(
throw std::invalid_argument(
"Softmax is defined only for floating point types");
break;
case float32:
@@ -164,9 +117,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case float64:
softmax<double, double>(in, out);
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");

View File

@@ -5,8 +5,8 @@
#include <cmath>
#include <numeric>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/primitives.h"
@@ -287,7 +287,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
} // namespace
void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
void ArgSort::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -312,8 +312,6 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
@@ -323,7 +321,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
void Sort::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -348,8 +346,6 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
return sort<int64_t>(in, out, axis_);
case float32:
return sort<float>(in, out, axis_);
case float64:
return sort<double>(in, out, axis_);
case float16:
return sort<float16_t>(in, out, axis_);
case bfloat16:
@@ -359,7 +355,7 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -384,8 +380,6 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
@@ -395,7 +389,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
void Partition::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -420,8 +414,6 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
return partition<int64_t>(in, out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_);
case float64:
return partition<double>(in, out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_);
case bfloat16:

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -137,9 +137,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
}
}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
}

View File

@@ -3,10 +3,12 @@
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
@@ -14,7 +16,7 @@ enum class TernaryOpType {
General,
};
inline TernaryOpType
TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
@@ -31,7 +33,7 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
return topt;
}
inline void set_ternary_op_output_data(
void set_ternary_op_output_data(
const array& a,
const array& b,
const array& c,
@@ -74,5 +76,152 @@ inline void set_ternary_op_output_data(
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& c_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
a,
b,
c,
out,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
axis + 1);
} else {
*out = op(*a, *b, *c);
}
a += stride_a;
b += stride_b;
c += stride_c;
out += stride_out;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<T3>();
int ndim = shape.size();
switch (ndim) {
case 1:
ternary_op_dims<T1, T2, T3, U, Op, 1>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
}
} // namespace
} // namespace mlx::core

View File

@@ -1,6 +1,6 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/cpu/threefry.h"
#include "mlx/backend/common/threefry.h"
namespace mlx::core::random {

View File

@@ -5,11 +5,12 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
void set_unary_output_data(const array& in, array& out) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
@@ -37,19 +38,8 @@ void unary_op(const array& a, array& out, Op op) {
if (a.flags().contiguous) {
set_unary_output_data(a, out);
U* dst = out.data<U>();
constexpr int N = simd::max_size<T>;
size_t size = a.data_size();
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a_ptr)));
size -= N;
a_ptr += N;
dst += N;
}
while (size > 0) {
*dst = op(*a_ptr);
size--;
dst++;
a_ptr++;
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -104,9 +94,6 @@ void unary(const array& a, array& out, Op op) {
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
@@ -128,9 +115,6 @@ void unary_fp(const array& a, array& out, Op op) {
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;
@@ -141,38 +125,6 @@ void unary_fp(const array& a, array& out, Op op) {
}
}
template <typename Op>
void unary_int(const array& a, array& out, Op op) {
switch (out.dtype()) {
case uint8:
unary_op<uint8_t>(a, out, op);
break;
case uint16:
unary_op<uint16_t>(a, out, op);
break;
case uint32:
unary_op<uint32_t>(a, out, op);
break;
case uint64:
unary_op<uint64_t>(a, out, op);
break;
case int8:
unary_op<int8_t>(a, out, op);
break;
case int16:
unary_op<int16_t>(a, out, op);
break;
case int32:
unary_op<int32_t>(a, out, op);
break;
case int64:
unary_op<int64_t>(a, out, op);
break;
default:
std::ostringstream err;
err << "[unary_int] Does not support " << out.dtype();
throw std::runtime_error(err.str());
}
}
} // namespace
} // namespace mlx::core

View File

@@ -1,82 +0,0 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
set(COMPILE_DEPS
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
simd/simd.h
simd/base_simd.h
simd/math.h
simd/type.h
unary_ops.h
binary_ops.h)
if(MSVC)
set(SHELL_EXT ps1)
set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
set(SHELL_EXT sh)
set(SHELL_CMD bash)
endif()
add_custom_command(
OUTPUT compiled_preamble.cpp
COMMAND
${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
DEPENDS make_compiled_preamble.${SHELL_EXT} compiled_preamble.h
${COMPILE_DEPS})
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
add_dependencies(mlx cpu_compiled_preamble)
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cblas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
endif()
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../no_cpu/compiled.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()

View File

@@ -1,373 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
template <typename Op>
struct VectorScalar {
Op op;
VectorScalar(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *b;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
dst += N;
a += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, scalar);
dst++;
a++;
}
}
};
template <typename Op>
struct ScalarVector {
Op op;
ScalarVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
T scalar = *a;
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
dst += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(scalar, *b);
dst++;
b++;
}
}
};
template <typename Op>
struct VectorVector {
Op op;
VectorVector(Op op_) : op(op_) {}
template <typename T, typename U>
void operator()(const T* a, const T* b, U* dst, int size) {
constexpr int N = simd::max_size<T>;
while (size >= N) {
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
dst += N;
a += N;
b += N;
size -= N;
}
while (size-- > 0) {
*dst = op(*a, *b);
dst++;
a++;
b++;
}
}
};
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
const T* a,
const T* b,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
binary_op_dims<T, U, Op, D - 1, Strided>(
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
} else {
if constexpr (Strided) {
op(a, b, out, stride_out);
} else {
*out = op(*a, *b);
}
}
out += stride_out;
a += stride_a;
b += stride_b;
}
}
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
const array& a,
const array& b,
array& out,
Op op,
int dim,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
switch (dim) {
case 1:
binary_op_dims<T, U, Op, 1, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 2:
binary_op_dims<T, U, Op, 2, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
case 3:
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr,
b_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
out_strides,
dim - 3);
a_it.step();
b_it.step();
}
}
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == BinaryOpType::ScalarVector) {
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == BinaryOpType::VectorScalar) {
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == BinaryOpType::VectorVector) {
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
// General computation so let's try to optimize
auto [new_shape, new_strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), out.strides()});
const auto& a_strides = new_strides[0];
const auto& b_strides = new_strides[1];
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
return d + 1;
};
auto a_rc_dim = leftmost_rc_dim(a_strides);
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}
return d + 1;
};
auto a_s_dim = leftmost_s_dim(a_strides);
auto b_s_dim = leftmost_s_dim(b_strides);
auto ndim = new_shape.size();
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = BinaryOpType::ScalarVector;
dim = d;
}
// Can be sure dim > 0 since otherwise we would have used one of the fully
// contiguous methods above. Except for the case that the flags do not
// correspond to the underlying contiguity.
if (dim == 0 || strides[dim - 1] < 16) {
bopt = BinaryOpType::General;
dim = ndim;
}
switch (bopt) {
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorVector{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
VectorScalar{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U, true>(
a,
b,
out,
ScalarVector{op},
dim,
new_shape,
a_strides,
b_strides,
strides);
break;
default:
binary_op_dispatch_dims<T, U, false>(
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
break;
}
}
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
binary_op<T, T>(a, b, out, op);
}
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
break;
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
case float16:
binary_op<float16_t>(a, b, out, op);
break;
case float32:
binary_op<float>(a, b, out, op);
break;
case float64:
binary_op<double>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;
case complex64:
binary_op<complex64_t>(a, b, out, op);
break;
}
}
} // namespace mlx::core

View File

@@ -1,98 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core::detail {
using namespace mlx::core::simd;
#define BINARY_SINGLE() \
template <typename T> \
T operator()(T x, T y) { \
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
}
#define DEFAULT_BINARY_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
return op(x, y); \
} \
BINARY_SINGLE() \
};
DEFAULT_BINARY_OP(Add, operator+)
DEFAULT_BINARY_OP(ArcTan2, atan2)
DEFAULT_BINARY_OP(Divide, operator/)
DEFAULT_BINARY_OP(Multiply, operator*)
DEFAULT_BINARY_OP(Subtract, operator-)
DEFAULT_BINARY_OP(LogicalAnd, operator&&)
DEFAULT_BINARY_OP(LogicalOr, operator||)
DEFAULT_BINARY_OP(BitwiseAnd, operator&)
DEFAULT_BINARY_OP(BitwiseOr, operator|)
DEFAULT_BINARY_OP(BitwiseXor, operator^)
DEFAULT_BINARY_OP(LeftShift, operator<<)
DEFAULT_BINARY_OP(RightShift, operator>>)
DEFAULT_BINARY_OP(Remainder, remainder)
DEFAULT_BINARY_OP(Maximum, maximum)
DEFAULT_BINARY_OP(Minimum, minimum)
DEFAULT_BINARY_OP(Power, pow)
#define DEFAULT_BOOL_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
return op(x, y); \
} \
template <typename T> \
bool operator()(T x, T y) { \
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
} \
};
DEFAULT_BOOL_OP(Equal, operator==)
DEFAULT_BOOL_OP(Greater, operator>)
DEFAULT_BOOL_OP(GreaterEqual, operator>=)
DEFAULT_BOOL_OP(Less, operator<)
DEFAULT_BOOL_OP(LessEqual, operator<=)
DEFAULT_BOOL_OP(NotEqual, operator!=)
struct NaNEqual {
template <int N, typename T>
Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {
return x == y || (isnan(x) && isnan(y));
}
template <typename T>
bool operator()(T x, T y) {
return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
}
};
struct LogAddExp {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
auto maxval = maximum(x, y);
auto minval = minimum(x, y);
auto mask = minval == -inf || maxval == inf;
auto out = maxval + log1p(exp(minval - maxval));
return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
}
BINARY_SINGLE()
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
.value;
}
template <int N, typename T>
Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {
return select(condition, x, y);
}
};
} // namespace mlx::core::detail

View File

@@ -1,24 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
void copy_inplace(
const array& src,
array& dst,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -1,20 +0,0 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
template <typename T>
void matmul(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta);
} // namespace mlx::core

View File

@@ -1,157 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h>
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/dtype.h"
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
}
void matmul_bnns(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
/* bool a_is_weights = */ false,
/* bool b_is_weights = */ false,
/* BNNSNDArrayDescriptor iA_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, lda, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor iB_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, ldb, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
/* BNNSNDArrayDescriptor o_desc = */
BNNSNDArrayDescriptor{
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
{N, M, 0, 0, 0, 0, 0, 0},
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
{1, N, 0, 0, 0, 0, 0, 0},
/* void * _Nullable data = */ nullptr,
/* BNNSDataType data_type = */ bnns_dtype,
/* void * _Nullable table_data = */ nullptr,
/* BNNSDataType table_data_type = */ bnns_dtype,
/* float data_scale = */ 1.0,
/* float data_bias = */ 0.0,
},
};
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
a.data<uint8_t>() +
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
b.data<uint8_t>() +
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
out.data<uint8_t>() + M * N * i * out.itemsize());
}
BNNSFilterDestroy(bnns_filter);
}
template <>
void matmul<float16_t>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
}
template <>
void matmul<bfloat16_t>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
}
} // namespace mlx::core

View File

@@ -1,79 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/lapack.h"
namespace mlx::core {
template <>
void matmul<float>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,21 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const array&,
const array&,
array&,
bool,
bool,
size_t,
size_t,
float,
float) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core

View File

@@ -1,88 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
void lu_factor_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices) {
int M = a.shape(-2);
int N = a.shape(-1);
// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
auto flags = lu.flags();
flags.col_contiguous = ndim == 2;
flags.row_contiguous = false;
flags.contiguous = true;
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
auto a_ptr = lu.data<float>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
MLX_LAPACK_FUNC(sgetrf)
(/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
// Subtract 1 to get 0-based index
for (int j = 0; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
}
void LUF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
} // namespace mlx::core

View File

@@ -1,82 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/primitives.h"
namespace mlx::core {
void matmul_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (out.dtype() == float32) {
matmul<float>(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == float16) {
matmul<float16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}
}
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (inputs[0].shape(-1) == 0) {
std::memset(out.data<void>(), 0, out.nbytes());
return;
}
return matmul_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype);
return matmul_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

View File

@@ -1,389 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/arange.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/threefry.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_inplace(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
int64_t compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes) {
auto compute_offset = [&strides, &axes](const auto* indices) {
int64_t offset = 0;
for (int i = 0; i < axes.size(); ++i) {
offset += indices[i] * strides[axes[i]];
}
return offset;
};
switch (indices.dtype()) {
case int8:
case uint8:
return compute_offset(indices.data<uint8_t>());
case int16:
case uint16:
return compute_offset(indices.data<uint16_t>());
case int32:
case uint32:
return compute_offset(indices.data<uint32_t>());
case int64:
case uint64:
return compute_offset(indices.data<uint64_t>());
default:
throw std::runtime_error("Invalid indices type.");
}
}
void AsStrided::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Broadcast::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Copy::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomTransforms::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void ExpandDims::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void NumberOfElements::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Slice::eval_cpu(const std::vector<array>& inputs, array& out) {
slice(inputs[0], out, start_indices_, strides_);
}
void Split::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Squeeze::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void StopGradient::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
arange(inputs, out, start_, step_);
}
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
}
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
std::vector<int> sizes;
sizes.push_back(0);
for (auto& p : inputs) {
sizes.push_back(p.shape(axis_));
}
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto strides = out.strides();
auto flags = out.flags();
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral);
}
}
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Flatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Unflatten::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
assert(in.dtype() == out.dtype());
CopyType ctype;
if (in.data_size() == 1) {
ctype = CopyType::Scalar;
} else if (in.flags().contiguous) {
ctype = CopyType::Vector;
} else {
ctype = CopyType::General;
}
copy(in, out, ctype);
}
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
load(out, offset_, reader_, swap_endianness_);
}
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];
auto& val = inputs[1];
// Padding value must be a scalar
assert(val.size() == 1);
// Padding value, input and output must be of the same type
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
// Fill output with val
copy(val, out, CopyType::Scalar);
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
// Extract slice from output where input will be pasted
array out_slice(in.shape(), out.dtype(), nullptr, {});
out_slice.copy_shared_buffer(
out, out.strides(), out.flags(), out_slice.size(), data_offset);
// Copy input values into the slice
copy_inplace(in, out_slice, CopyType::GeneralGeneral);
}
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>();
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
// Get ith key
auto kidx = 2 * i;
auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides());
auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides());
auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]);
std::pair<uintptr_t, uintptr_t> count{0, half_size + !even};
for (; count.first + 1 < half_size; count.first++, count.second++) {
std::tie(ptr[count.first], ptr[count.second]) =
random::threefry2x32_hash(key, count);
}
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
}
}
}
void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
reshape(inputs[0], out);
}
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_);
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const Shape& data_shape = */ out.shape(),
/* const Strides& i_strides = */ in.strides(),
/* const Strides& o_strides = */ out.strides(),
/* int64_t i_offset = */ i_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
// Copy or move src to dst
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_);
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ o_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] =
prepare_slice(out, start_indices_, strides_);
// Do copy
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd.strides(),
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in);
copy_inplace(in_tmp, tmp, CopyType::General);
} else {
copy_inplace(in, tmp, CopyType::General);
}
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@@ -1,562 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <functional>
#include <limits>
#include "mlx/backend/common/reduce.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
namespace mlx::core {
template <typename U>
struct Limits {
static const U max;
static const U min;
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr type max = std::numeric_limits<type>::max(); \
static constexpr type min = std::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static const type max; \
static const type min; \
};
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(double);
instantiate_float_limit(complex64_t);
template <>
struct Limits<bool> {
static constexpr bool max = true;
static constexpr bool min = false;
};
const float Limits<float>::max = std::numeric_limits<float>::infinity();
const float Limits<float>::min = -std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::max =
std::numeric_limits<float>::infinity();
const bfloat16_t Limits<bfloat16_t>::min =
-std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
const double Limits<double>::max = std::numeric_limits<double>::infinity();
const double Limits<double>::min = -std::numeric_limits<double>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
-std::numeric_limits<float>::infinity();
template <typename T, typename U, typename Op>
void strided_reduce(
const T* x,
U* accumulator,
int size,
size_t stride,
Op op) {
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
for (int i = 0; i < size; i++) {
U* moving_accumulator = accumulator;
auto s = stride;
while (s >= N) {
auto acc = simd::load<U, N>(moving_accumulator);
auto v = simd::Simd<U, N>(simd::load<T, N>(x));
simd::store<U, N>(moving_accumulator, op(acc, v));
moving_accumulator += N;
x += N;
s -= N;
}
while (s-- > 0) {
*moving_accumulator = op(*moving_accumulator, *x);
moving_accumulator++;
x++;
}
}
};
template <typename T, typename U, typename Op>
void contiguous_reduce(const T* x, U* accumulator, int size, Op op, U init) {
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
simd::Simd<U, N> accumulator_v(init);
while (size >= N) {
accumulator_v = op(accumulator_v, simd::Simd<U, N>(simd::load<T, N>(x)));
x += N;
size -= N;
}
*accumulator = op(*accumulator, op(accumulator_v));
while (size-- > 0) {
*accumulator = op(*accumulator, *x);
x++;
}
}
// Helper for the ndimensional strided loop
void nd_loop(
std::function<void(int)> callback,
const Shape& shape,
const Strides& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
template <typename T, typename U, typename Op>
void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Op op) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
ReductionPlan plan = get_reduction_plan(x, axes);
if (plan.type == ContiguousAllReduce) {
U* out_ptr = out.data<U>();
*out_ptr = init;
contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
}
return;
}
if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) {
int reduction_size = plan.shape.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
}
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
op,
init);
},
plan.shape,
plan.strides);
}
}
return;
}
if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
x_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
return;
}
if (plan.type == GeneralStridedReduce ||
plan.type == ContiguousStridedReduce) {
int reduction_size = plan.shape.back();
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
x_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
op);
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
return;
}
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = op(val, *(x_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
}
}
struct AndReduce {
template <typename T>
bool operator()(bool x, T y) {
return x & (y != 0);
}
bool operator()(bool x, bool y) {
return x & y;
}
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return x & (y != 0);
};
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x & y;
};
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::all(x);
};
};
struct OrReduce {
template <typename T>
bool operator()(bool x, T y) {
return x | (y != 0);
}
bool operator()(bool x, bool y) {
return x | y;
}
template <int N, typename T>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<T, N> x) {
return x | (y != 0);
};
template <int N>
simd::Simd<bool, N> operator()(simd::Simd<bool, N> y, simd::Simd<bool, N> x) {
return x | y;
};
template <int N, typename T>
bool operator()(simd::Simd<T, N> x) {
return simd::any(x);
};
};
struct MaxReduce {
template <typename T>
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::maximum(x, y);
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::max(x);
};
};
struct MinReduce {
template <typename T>
T operator()(T y, T x) {
return (*this)(simd::Simd<T, 1>(x), simd::Simd<T, 1>(y)).value;
};
template <int N, typename T>
simd::Simd<T, N> operator()(simd::Simd<T, N> y, simd::Simd<T, N> x) {
return simd::minimum(x, y);
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::min(x);
};
};
struct SumReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x + y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return y + x;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::sum(x);
};
};
struct ProdReduce {
template <typename T, typename U>
U operator()(U y, T x) {
return x * y;
};
template <int N, typename T, typename U>
simd::Simd<U, N> operator()(simd::Simd<U, N> y, simd::Simd<T, N> x) {
return x * y;
};
template <int N, typename T>
T operator()(simd::Simd<T, N> x) {
return simd::prod(x);
};
};
template <typename InT>
void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool>(in, out, axes, true, AndReduce());
} else {
reduction_op<InT, bool>(in, out, axes, false, OrReduce());
}
}
template <typename InT>
void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
} else {
reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
}
}
}
template <typename InT>
void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT>(in, out, axes, init, MinReduce());
}
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
}
}
} // namespace mlx::core

View File

@@ -1,316 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename U, typename Op>
void contiguous_scan(
const T* input,
U* output,
int count,
int stride,
bool reverse,
bool inclusive,
const Op& op,
U init) {
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
*output = *input;
for (int j = 1; j < stride; j++) {
input++;
output++;
*output = op(*(output - 1), *input);
}
output++;
input++;
}
} else {
for (int i = 0; i < count; i++) {
*output = init;
for (int j = 1; j < stride; j++) {
*(output + 1) = op(*output, *input);
input++;
output++;
}
output++;
input++;
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = *input;
for (int j = 1; j < stride; j++) {
input--;
output--;
*output = op(*(output + 1), *input);
}
output += stride;
input += stride;
}
} else {
for (int i = 0; i < count; i++) {
output += stride - 1;
input += stride - 1;
*output = init;
for (int j = 1; j < stride; j++) {
*(output - 1) = op(*output, *input);
input--;
output--;
}
output += stride;
input += stride;
}
}
}
};
template <typename T, typename U, typename Op>
void strided_scan(
const T* input,
U* output,
int count,
int size,
int stride,
bool reverse,
bool inclusive,
const Op& op,
U init) {
// TODO: Vectorize the following naive implementation
if (!reverse) {
if (inclusive) {
for (int i = 0; i < count; i++) {
std::copy(input, input + stride, output);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
*output = op(*(output - stride), *input);
output++;
input++;
}
}
}
} else {
for (int i = 0; i < count; i++) {
std::fill(output, output + stride, init);
output += stride;
input += stride;
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
*output = op(*(output - stride), *(input - stride));
output++;
input++;
}
}
}
}
} else {
if (inclusive) {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::copy(input, input + stride, output);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
*output = op(*(output + stride), *input);
}
}
output += size * stride;
input += size * stride;
}
} else {
for (int i = 0; i < count; i++) {
output += (size - 1) * stride;
input += (size - 1) * stride;
std::fill(output, output + stride, init);
for (int j = 1; j < size; j++) {
for (int k = 0; k < stride; k++) {
output--;
input--;
*output = op(*(output + stride), *(input + stride));
}
}
output += size * stride;
input += size * stride;
}
}
}
};
template <typename T, typename U, typename Op>
void scan_op(
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive,
const Op& op,
U init) {
output.set_data(allocator::malloc_or_wait(output.nbytes()));
if (input.flags().row_contiguous) {
if (input.strides()[axis] == 1) {
contiguous_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis),
input.shape(axis),
reverse,
inclusive,
op,
init);
} else {
strided_scan(
input.data<T>(),
output.data<U>(),
input.size() / input.shape(axis) / input.strides()[axis],
input.shape(axis),
input.strides()[axis],
reverse,
inclusive,
op,
init);
}
} else {
throw std::runtime_error("Scan op supports only contiguous inputs");
}
}
template <typename T, typename U>
void scan_dispatch(
Scan::ReduceType rtype,
const array& input,
array& output,
int axis,
bool reverse,
bool inclusive) {
switch (rtype) {
case Scan::Sum: {
auto op = [](U y, T x) { return y + x; };
auto init = static_cast<U>(0);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Prod: {
auto op = [](U y, T x) { return y * x; };
auto init = static_cast<U>(1);
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Min: {
auto op = [](U y, T x) { return x < y ? x : y; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
case Scan::Max: {
auto op = [](U y, T x) { return x < y ? y : x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
scan_op<T, U>(input, output, axis, reverse, inclusive, op, init);
break;
}
}
}
} // namespace
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Ensure contiguity
auto in = inputs[0];
if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy(in, arr_copy, CopyType::General);
in = arr_copy;
}
switch (in.dtype()) {
case bool_: {
// We could do a full dtype x dtype switch but this is the only case
// where we accumulate in a different type, for now.
//
// TODO: If we add the option to accumulate floats in higher precision
// floats perhaps we should add the full all-to-all dispatch.
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
scan_dispatch<bool, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
} else {
scan_dispatch<bool, bool>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
}
break;
}
case uint8:
scan_dispatch<uint8_t, uint8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint16:
scan_dispatch<uint16_t, uint16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint32:
scan_dispatch<uint32_t, uint32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case uint64:
scan_dispatch<uint64_t, uint64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int8:
scan_dispatch<int8_t, int8_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int16:
scan_dispatch<int16_t, int16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int32:
scan_dispatch<int32_t, int32_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case int64:
scan_dispatch<int64_t, int64_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float16:
scan_dispatch<float16_t, float16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float32:
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case complex64:
throw std::runtime_error("Scan ops do not support complex types yet");
break;
}
}
} // namespace mlx::core

View File

@@ -1,56 +0,0 @@
#pragma once
#include "mlx/backend/cpu/simd/base_simd.h"
#if MLX_SIMD_LIBRARY_VERSION < 6
#include "mlx/backend/cpu/simd/neon_fp16_simd.h"
#endif
namespace mlx::core::simd {
#if MLX_SIMD_LIBRARY_VERSION >= 6
constexpr int N = 8;
template <int N>
struct ScalarT<float16_t, N> {
using v = _Float16;
};
#endif
template <>
static constexpr int max_size<float16_t> = N;
#define SIMD_FP16_DEFAULT_UNARY(op) \
template <> \
inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
Simd<float, N> in = v; \
return op(in); \
}
SIMD_FP16_DEFAULT_UNARY(acos)
SIMD_FP16_DEFAULT_UNARY(acosh)
SIMD_FP16_DEFAULT_UNARY(asin)
SIMD_FP16_DEFAULT_UNARY(asinh)
SIMD_FP16_DEFAULT_UNARY(atan)
SIMD_FP16_DEFAULT_UNARY(atanh)
SIMD_FP16_DEFAULT_UNARY(cosh)
SIMD_FP16_DEFAULT_UNARY(expm1)
SIMD_FP16_DEFAULT_UNARY(log)
SIMD_FP16_DEFAULT_UNARY(log2)
SIMD_FP16_DEFAULT_UNARY(log10)
SIMD_FP16_DEFAULT_UNARY(log1p)
SIMD_FP16_DEFAULT_UNARY(sinh)
SIMD_FP16_DEFAULT_UNARY(tan)
SIMD_FP16_DEFAULT_UNARY(tanh)
#define SIMD_FP16_DEFAULT_BINARY(op) \
template <> \
inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
Simd<float, N> a = x; \
Simd<float, N> b = y; \
return op(a, b); \
}
SIMD_FP16_DEFAULT_BINARY(atan2)
SIMD_FP16_DEFAULT_BINARY(remainder)
SIMD_FP16_DEFAULT_BINARY(pow)
} // namespace mlx::core::simd

View File

@@ -1,308 +0,0 @@
#pragma once
#include <simd/math.h>
#include <simd/vector.h>
#include <stdint.h>
#include <cmath>
#include <complex>
#include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in sims/base.h
// __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
__WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
__TV_OS_VERSION_MIN_REQUIRED >= 180000
#define MLX_SIMD_LIBRARY_VERSION 6
#else
#define MLX_SIMD_LIBRARY_VERSION 5
#endif
namespace mlx::core::simd {
// Apple simd namespace
namespace asd = ::simd;
// This indirection is needed to remap certain types to ones that accelerate
// SIMD can handle
template <typename T, int N>
struct ScalarT {
using v = T;
};
template <int N>
struct ScalarT<bool, N> {
using v = char;
};
template <int N>
struct ScalarT<int8_t, N> {
using v = char;
};
template <int N>
struct ScalarT<uint64_t, N> {
using v = unsigned long;
};
template <int N>
struct ScalarT<int64_t, N> {
using v = long;
};
template <typename T, int N>
struct Simd {
static constexpr int size = N;
using scalar_t = typename ScalarT<T, N>::v;
Simd<T, N>() {}
template <typename U>
Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
template <typename U>
Simd<T, N>(U v) : value(v){};
Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
x.value, y.value);
};
T operator[](int idx) const {
return reinterpret_cast<const T*>(&value)[idx];
}
T& operator[](int idx) {
return reinterpret_cast<T*>(&value)[idx];
}
typename asd::Vector<scalar_t, N>::packed_t value;
};
// Values chosen based on benchmarks on M3 Max
// TODO: consider choosing these more optimally
template <>
static constexpr int max_size<int8_t> = 16;
template <>
static constexpr int max_size<int16_t> = 16;
template <>
static constexpr int max_size<int> = 8;
template <>
static constexpr int max_size<int64_t> = 4;
template <>
static constexpr int max_size<uint8_t> = 16;
template <>
static constexpr int max_size<uint16_t> = 16;
template <>
static constexpr int max_size<uint32_t> = 8;
template <>
static constexpr int max_size<uint64_t> = 4;
template <>
static constexpr int max_size<float> = 8;
template <>
static constexpr int max_size<double> = 4;
#define SIMD_DEFAULT_UNARY(name, op) \
template <typename T, int N> \
Simd<T, N> name(Simd<T, N> v) { \
return op(v.value); \
}
SIMD_DEFAULT_UNARY(abs, asd::abs)
SIMD_DEFAULT_UNARY(floor, asd::floor)
SIMD_DEFAULT_UNARY(acos, asd::acos)
SIMD_DEFAULT_UNARY(acosh, asd::acosh)
SIMD_DEFAULT_UNARY(asin, asd::asin)
SIMD_DEFAULT_UNARY(asinh, asd::asinh)
SIMD_DEFAULT_UNARY(atan, asd::atan)
SIMD_DEFAULT_UNARY(atanh, asd::atanh)
SIMD_DEFAULT_UNARY(ceil, asd::ceil)
SIMD_DEFAULT_UNARY(cosh, asd::cosh)
SIMD_DEFAULT_UNARY(expm1, asd::expm1)
SIMD_DEFAULT_UNARY(log, asd::log)
SIMD_DEFAULT_UNARY(log2, asd::log2)
SIMD_DEFAULT_UNARY(log10, asd::log10)
SIMD_DEFAULT_UNARY(log1p, asd::log1p)
SIMD_DEFAULT_UNARY(rint, asd::rint)
SIMD_DEFAULT_UNARY(sinh, asd::sinh)
SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
SIMD_DEFAULT_UNARY(recip, asd::recip)
SIMD_DEFAULT_UNARY(tan, asd::tan)
SIMD_DEFAULT_UNARY(tanh, asd::tanh)
template <typename T, int N>
Simd<T, N> operator-(Simd<T, N> v) {
return -v.value;
}
template <typename T, int N>
Simd<T, N> operator~(Simd<T, N> v) {
return ~v.value;
}
template <typename T, int N>
Simd<bool, N> isnan(Simd<T, N> v) {
return asd::convert<char>(v.value != v.value);
}
// No simd_boolN in accelerate, use int8_t instead
template <typename T, int N>
Simd<bool, N> operator!(Simd<T, N> v) {
return asd::convert<char>(!v.value);
}
#define SIMD_DEFAULT_BINARY(OP) \
template <typename T, typename U, int N> \
Simd<T, N> operator OP(Simd<T, N> x, U y) { \
return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
} \
template <typename T1, typename T2, int N> \
Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
} \
template <typename T1, typename T2, int N> \
Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
}
SIMD_DEFAULT_BINARY(+)
SIMD_DEFAULT_BINARY(-)
SIMD_DEFAULT_BINARY(/)
SIMD_DEFAULT_BINARY(*)
SIMD_DEFAULT_BINARY(<<)
SIMD_DEFAULT_BINARY(>>)
SIMD_DEFAULT_BINARY(|)
SIMD_DEFAULT_BINARY(^)
SIMD_DEFAULT_BINARY(&)
SIMD_DEFAULT_BINARY(&&)
SIMD_DEFAULT_BINARY(||)
#define SIMD_DEFAULT_COMPARISONS(OP) \
template <int N, typename T, typename U> \
Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
return asd::convert<char>(a.value OP b); \
} \
template <int N, typename T, typename U> \
Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
return asd::convert<char>(a OP b.value); \
} \
template <int N, typename T1, typename T2> \
Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
return asd::convert<char>(a.value OP b.value); \
}
SIMD_DEFAULT_COMPARISONS(>)
SIMD_DEFAULT_COMPARISONS(<)
SIMD_DEFAULT_COMPARISONS(>=)
SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::max(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::min(a.value, b.value);
}
template <typename T, int N>
Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
Simd<T, N> r;
if constexpr (!std::is_integral_v<T>) {
r = asd::remainder(a.value, b.value);
} else {
r = a - b * (a / b);
}
if constexpr (std::is_signed_v<T>) {
auto mask = r != 0 && (r < 0 != b < 0);
r = select(mask, r + b, r);
}
return r;
}
template <typename MaskT, typename T1, typename T2, int N>
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
if constexpr (sizeof(T1) == 1) {
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
} else if constexpr (sizeof(T1) == 2) {
return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
} else if constexpr (sizeof(T1) == 4) {
return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
} else {
return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
}
}
template <typename T, int N>
Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
if constexpr (!std::is_integral_v<T>) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
exp = exp >> 1;
}
return res;
}
}
template <typename T, int N>
Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
return asd::clamp(v.value, min.value, max.value);
}
template <typename T, typename U, int N>
Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
}
// Reductions
template <typename T, int N>
bool all(Simd<T, N> x) {
return asd::all(x.value);
}
template <typename T, int N>
bool any(Simd<T, N> x) {
return asd::any(x.value);
}
template <typename T, int N>
T sum(Simd<T, N> x) {
return asd::reduce_add(x.value);
}
template <typename T, int N>
T max(Simd<T, N> x) {
return asd::reduce_max(x.value);
}
template <typename T, int N>
T min(Simd<T, N> x) {
return asd::reduce_min(x.value);
}
template <typename T, int N>
T prod(Simd<T, N> x) {
auto ptr = (T*)&x;
auto lhs = load<T, N / 2>(ptr);
auto rhs = load<T, N / 2>(ptr + N / 2);
return prod(lhs * rhs);
}
} // namespace mlx::core::simd
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "mlx/backend/cpu/simd/accelerate_fp16_simd.h"
#endif

View File

@@ -1,259 +0,0 @@
#pragma once
#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <functional>
namespace mlx::core::simd {
template <typename T, int N>
struct Simd;
template <typename T>
static constexpr int max_size = 1;
template <typename T>
struct Simd<T, 1> {
static constexpr int size = 1;
T value;
Simd() {}
template <typename U>
Simd(Simd<U, 1> v) : value(v.value) {}
template <typename U>
Simd(U v) : value(v) {}
};
template <typename T, int N>
Simd<T, N> load(const T* x) {
return *(Simd<T, N>*)x;
}
template <typename T, int N>
void store(T* dst, Simd<T, N> x) {
// Maintain invariant that bool is either 0 or 1 as
// simd comparison ops set all bits in the result to 1
if constexpr (std::is_same_v<T, bool> && N > 1) {
x = x & 1;
}
*(Simd<T, N>*)dst = x;
}
template <typename, typename = void>
constexpr bool is_complex = false;
template <typename T>
constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
true;
template <typename T>
Simd<T, 1> rint(Simd<T, 1> in) {
if constexpr (is_complex<T>) {
return Simd<T, 1>{
T{std::rint(in.value.real()), std::rint(in.value.imag())}};
} else {
return Simd<T, 1>{std::rint(in.value)};
}
}
template <typename T>
Simd<T, 1> rsqrt(Simd<T, 1> in) {
return T(1.0) / sqrt(in);
}
template <typename T>
Simd<T, 1> recip(Simd<T, 1> in) {
return T(1.0) / in;
}
#define DEFAULT_UNARY(name, op) \
template <typename T> \
Simd<T, 1> name(Simd<T, 1> in) { \
return op(in.value); \
}
DEFAULT_UNARY(operator-, std::negate{})
DEFAULT_UNARY(operator!, std::logical_not{})
DEFAULT_UNARY(abs, std::abs)
DEFAULT_UNARY(acos, std::acos)
DEFAULT_UNARY(acosh, std::acosh)
DEFAULT_UNARY(asin, std::asin)
DEFAULT_UNARY(asinh, std::asinh)
DEFAULT_UNARY(atan, std::atan)
DEFAULT_UNARY(atanh, std::atanh)
DEFAULT_UNARY(ceil, std::ceil)
DEFAULT_UNARY(conj, std::conj)
DEFAULT_UNARY(cosh, std::cosh)
DEFAULT_UNARY(expm1, std::expm1)
DEFAULT_UNARY(floor, std::floor)
DEFAULT_UNARY(log, std::log)
DEFAULT_UNARY(log2, std::log2)
DEFAULT_UNARY(log10, std::log10)
DEFAULT_UNARY(log1p, std::log1p)
DEFAULT_UNARY(sinh, std::sinh)
DEFAULT_UNARY(sqrt, std::sqrt)
DEFAULT_UNARY(tan, std::tan)
DEFAULT_UNARY(tanh, std::tanh)
template <typename T>
Simd<T, 1> operator~(Simd<T, 1> in) {
return ~in.value;
}
template <typename T>
auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
return std::real(in.value);
}
template <typename T>
auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
return std::imag(in.value);
}
template <typename T>
Simd<bool, 1> isnan(Simd<T, 1> in) {
return std::isnan(in.value);
}
#define DEFAULT_BINARY(OP) \
template <typename T1, typename T2> \
auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
->Simd<decltype(a.value OP b.value), 1> { \
return a.value OP b.value; \
} \
template <typename T1, typename T2> \
auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
return a OP b.value; \
} \
template <typename T1, typename T2> \
auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
return a.value OP b; \
}
DEFAULT_BINARY(+)
DEFAULT_BINARY(-)
DEFAULT_BINARY(*)
DEFAULT_BINARY(/)
DEFAULT_BINARY(<<)
DEFAULT_BINARY(>>)
DEFAULT_BINARY(|)
DEFAULT_BINARY(^)
DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
T r;
if constexpr (std::is_integral_v<T>) {
r = a % b;
} else {
r = std::remainder(a, b);
}
if constexpr (std::is_signed_v<T>) {
if (r != 0 && (r < 0 != b < 0)) {
r += b;
}
}
return r;
}
template <typename T>
Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
if constexpr (!std::is_integral_v<T>) {
if (std::isnan(a)) {
return a;
}
}
return (a > b) ? a : b;
}
template <typename T>
Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;
T b = b_.value;
if constexpr (!std::is_integral_v<T>) {
if (std::isnan(a)) {
return a;
}
}
return (a < b) ? a : b;
}
template <typename T>
Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
T base = a.value;
T exp = b.value;
if constexpr (!std::is_integral_v<T>) {
return std::pow(base, exp);
} else {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
}
template <typename T>
Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
return std::atan2(a.value, b.value);
}
#define DEFAULT_COMPARISONS(OP) \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
return a.value OP b.value; \
} \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
return a OP b.value; \
} \
template <typename T1, typename T2> \
Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
return a.value OP b; \
}
DEFAULT_COMPARISONS(>)
DEFAULT_COMPARISONS(<)
DEFAULT_COMPARISONS(>=)
DEFAULT_COMPARISONS(<=)
DEFAULT_COMPARISONS(==)
DEFAULT_COMPARISONS(!=)
template <typename MaskT, typename T>
Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
return mask.value ? x.value : y.value;
}
template <typename T>
Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
return std::clamp(v.value, min.value, max.value);
}
template <typename T, typename U>
Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
return std::fma(x.value, y.value, Simd<T, 1>(z).value);
}
// Reductions
#define DEFAULT_REDUCTION(name, type) \
template <typename T> \
type name(Simd<T, 1> x) { \
return x.value; \
}
DEFAULT_REDUCTION(max, T)
DEFAULT_REDUCTION(min, T)
DEFAULT_REDUCTION(sum, T)
DEFAULT_REDUCTION(prod, T)
DEFAULT_REDUCTION(any, bool)
DEFAULT_REDUCTION(all, bool)
} // namespace mlx::core::simd

View File

@@ -1,193 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/type.h"
namespace mlx::core::simd {
constexpr float inf = std::numeric_limits<float>::infinity();
/**
* Compute exp(x) in an optimizer friendly way as follows:
*
* First change the problem to computing 2**y where y = x / ln(2).
*
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
* `ipart` and y2 is fractional part. For the integer part we perform bit
* shifting and for the fractional part we use a polynomial approximation.
*
* The algorithm and constants of the polynomial taken from
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
* from Cephes math library.
*
* Note: The implementation below is a general fast exp. There could be faster
* implementations for numbers strictly < 0.
*/
template <typename T, int N>
Simd<T, N> exp(Simd<T, N> in) {
if constexpr (is_complex<T>) {
return Simd<T, 1>{std::exp(in.value)};
} else {
Simd<float, N> x_init = in;
auto x = x_init * 1.442695f; // multiply with log_2(e)
Simd<float, N> ipart, fpart;
ipart = floor(x + 0.5);
fpart = x - ipart;
x = 1.535336188319500e-4f;
x = fma(x, fpart, 1.339887440266574e-3f);
x = fma(x, fpart, 9.618437357674640e-3f);
x = fma(x, fpart, 5.550332471162809e-2f);
x = fma(x, fpart, 2.402264791363012e-1f);
x = fma(x, fpart, 6.931472028550421e-1f);
x = fma(x, fpart, 1.000000000000000f);
// generate 2**ipart in the floating point representation using integer
// bitshifting
Simd<int, N> epart = (Simd<int, N>(ipart) + 127) << 23;
// Deal with NaN and Inf
auto result = select(isnan(x_init), x_init, (*(Simd<float, N>*)&epart) * x);
result = select(x_init > 88.0f, Simd<float, N>(inf), result);
result = select(x_init < -88.0f, Simd<float, N>(0), result);
return Simd<T, N>(result);
}
}
/* Implementation from:
* https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357
* which originally came from the Cephes math library.
*/
template <bool Sine, typename T, int N>
Simd<T, N> sincos(Simd<T, N> in) {
auto sign_mask_sin = in < 0;
in = abs(in);
Simd<float, N> x = in;
// scale by 4/Pi
auto y = x * 1.27323954473516f;
// store the integer part of y in mm0
Simd<uint32_t, N> emm2 = y;
// j=(j+1) & (~1) (see the cephes sources)
emm2 = emm2 + 1;
emm2 = emm2 & ~1;
y = emm2;
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
auto poly_mask = (emm2 & 2) != 0;
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
x = fma(y, Simd<float, N>(-0.78515625f), x);
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
// and the second polynom (Pi/4 <= x <= 0) in y2
auto z = x * x;
auto y1 =
fma(z, Simd<float, N>(2.443315711809948e-5f), -1.388731625493765e-3f);
auto y2 = fma(z, Simd<float, N>(-1.9515295891e-4f), 8.3321608736e-3f);
y1 = fma(y1, z, 4.166664568298827e-2f);
y2 = fma(y2, z, -1.6666654611e-1f);
y1 = y1 * z;
y2 = y2 * z;
y1 = y1 * z;
y2 = fma(x, y2, x);
y1 = fma(z, Simd<float, N>(-0.5f), y1);
y1 = y1 + 1.0f;
if constexpr (Sine) {
auto ys = select(poly_mask, y1, y2);
return select(sign_mask_sin, -ys, ys);
} else {
auto yc = select(poly_mask, y2, y1);
return select(sign_mask_cos, yc, -yc);
}
}
template <typename T, int N>
Simd<T, N> sin(Simd<T, N> x) {
if constexpr (is_complex<T>) {
return std::sin(x.value);
} else {
return sincos<true>(x);
}
}
template <typename T, int N>
Simd<T, N> cos(Simd<T, N> x) {
if constexpr (is_complex<T>) {
return std::cos(x.value);
} else {
return sincos<false>(x);
}
}
template <typename T, int N>
Simd<T, N> erf(Simd<T, N> x) {
// https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175
Simd<float, N> v = x;
auto t = recip(fma(Simd<float, N>(0.3275911f), abs(v), 1.0f));
auto r = fma(Simd<float, N>(1.061405429f), t, -1.453152027f);
r = fma(r, t, 1.421413741f);
r = fma(r, t, -0.284496736f);
r = fma(r, t, 0.254829592f);
auto e = -exp(-v * v);
auto result = Simd<T, N>(fma(e * t, r, 1.0f));
return select(x > 0, result, -result);
}
template <typename T, int N>
Simd<T, N> erfinv(Simd<T, N> a_) {
Simd<float, N> a = a_;
auto t = fma(a, 0.0f - a, 1.0f);
t = log(t);
auto lhs = [](auto t) {
Simd<float, N> p;
p = 3.03697567e-10f; // 0x1.4deb44p-32
p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
};
auto rhs = [](auto t) {
Simd<float, N> p;
p = 5.43877832e-9f; // 0x1.75c000p-28
p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
};
auto thresh = 6.125f;
// Compute both branches and select if N > 1
if constexpr (N == 1) {
if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793
return a * lhs(t);
} else { // maximum ulp error = 2.35002
return a * rhs(t);
}
} else {
return a * select(t > thresh, lhs(t), rhs(t));
}
}
} // namespace mlx::core::simd

View File

@@ -1,212 +0,0 @@
#pragma once
#include <arm_neon.h>
#include "mlx/backend/cpu/simd/base_simd.h"
namespace mlx::core::simd {
constexpr int N = 8;
template <>
struct Simd<float16_t, N> {
static constexpr int size = N;
using scalar_t = float16_t;
Simd<float16_t, N>() {}
template <typename U>
Simd<float16_t, N>(U v) : value(vdupq_n_f16(v)){};
Simd<float16_t, N>(float16x8_t v) : value(v){};
Simd<float16_t, N>(Simd<float, N> other) {
auto f32x4_a = *(float32x4_t*)(&other);
auto f32x4_b = *((float32x4_t*)(&other) + 1);
value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b);
};
Simd<float16_t, N>(Simd<uint16_t, N> other) {
value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value));
};
operator Simd<int16_t, N>() {
auto v = vcvtq_s16_f16(value);
return load<int16_t, N>((int16_t*)&v);
};
operator Simd<float, N>() {
float32x4x2_t v;
v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value));
v.val[1] = vcvt_high_f32_f16(value);
return load<float, N>((float*)&v);
}
float16_t operator[](int idx) const {
return reinterpret_cast<const float16_t*>(&value)[idx];
}
float16_t& operator[](int idx) {
return reinterpret_cast<float16_t*>(&value)[idx];
}
float16x8_t value;
};
#define DEFINE_NEON_UNARY_OP(name, op) \
inline Simd<float16_t, N> name(Simd<float16_t, N> a) { \
return Simd<float16_t, N>{op(a.value)}; \
}
DEFINE_NEON_UNARY_OP(abs, vabsq_f16)
DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16)
DEFINE_NEON_UNARY_OP(floor, vrndmq_f16)
DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16)
DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16)
DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16)
DEFINE_NEON_UNARY_OP(rint, vrndnq_f16)
#define DEFINE_NEON_BINARY_OP(name, op) \
inline Simd<float16_t, N> name(Simd<float16_t, N> a, Simd<float16_t, N> b) { \
return op(a.value, b.value); \
} \
template <typename T> \
Simd<float16_t, N> name(Simd<float16_t, N> a, T b) { \
return op(a.value, Simd<float16_t, N>(b).value); \
} \
template <typename T> \
Simd<float16_t, N> name(T a, Simd<float16_t, N> b) { \
return op(Simd<float16_t, N>(a).value, b.value); \
}
inline Simd<float16_t, N> operator!(Simd<float16_t, N> v) {
auto out = vceqzq_f16(v.value);
return Simd<uint16_t, N>(*(uint16_t*)&out);
}
inline Simd<float16_t, N> operator-(Simd<float16_t, N> v) {
return vnegq_f16(v.value);
}
DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16)
DEFINE_NEON_BINARY_OP(minimum, vminq_f16)
DEFINE_NEON_BINARY_OP(operator+, vaddq_f16)
DEFINE_NEON_BINARY_OP(operator-, vsubq_f16)
DEFINE_NEON_BINARY_OP(operator*, vmulq_f16)
DEFINE_NEON_BINARY_OP(operator/, vdivq_f16)
#define DEFINE_NEON_COMPARISON(Op, op) \
template <typename T> \
Simd<bool, N> operator Op(Simd<float16_t, N> a, T b) { \
auto out = op(a.value, Simd<float16_t, N>(b).value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
} \
template <typename T> \
Simd<bool, N> operator Op(T a, Simd<float16_t, N> b) { \
auto out = op(Simd<float16_t, N>(a).value, b.value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
} \
inline Simd<bool, N> operator Op( \
Simd<float16_t, N> a, Simd<float16_t, N> b) { \
auto out = op(a.value, b.value); \
return Simd<uint16_t, N>(*(uint16_t*)(&out)); \
}
DEFINE_NEON_COMPARISON(==, vceqq_f16)
DEFINE_NEON_COMPARISON(>=, vcgeq_f16)
DEFINE_NEON_COMPARISON(<=, vcleq_f16)
DEFINE_NEON_COMPARISON(>, vcgtq_f16)
DEFINE_NEON_COMPARISON(<, vcltq_f16)
template <typename T>
Simd<bool, N> operator!=(Simd<float16_t, N> a, T b) {
return !(a == b);
}
template <typename T>
Simd<bool, N> operator!=(T a, Simd<float16_t, N> b) {
return !(a == b);
}
inline Simd<bool, N> operator!=(Simd<float16_t, N> a, Simd<float16_t, N> b) {
return !(a == b);
}
inline Simd<float16_t, N> operator||(
Simd<float16_t, N> a,
Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
template <typename T>
Simd<float16_t, N> operator||(Simd<float16_t, N> a, T b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
template <typename T>
Simd<float16_t, N> operator||(T a, Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) || (b != 0));
}
inline Simd<float16_t, N> operator&&(
Simd<float16_t, N> a,
Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <typename T>
Simd<float16_t, N> operator&&(Simd<float16_t, N> a, T b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <typename T>
Simd<float16_t, N> operator&&(T a, Simd<float16_t, N> b) {
return Simd<uint16_t, N>((a != 0) && (b != 0));
}
template <>
inline Simd<bool, N> isnan(Simd<float16_t, N> v) {
return v != v;
}
template <>
inline Simd<float16_t, N>
clamp(Simd<float16_t, N> v, Simd<float16_t, N> min, Simd<float16_t, N> max) {
return minimum(maximum(v, min), max);
}
template <typename T>
Simd<float16_t, N> fma(Simd<float16_t, N> x, Simd<float16_t, N> y, T z) {
return vfmaq_f16(x.value, y.value, Simd<float16_t, N>(z).value);
}
template <typename MaskT>
Simd<float16_t, N>
select(Simd<MaskT, N> mask, Simd<float16_t, N> x, Simd<float16_t, N> y) {
return vbslq_f16(Simd<uint16_t, N>(mask).value, x.value, y.value);
}
// Reductions
inline float16_t max(Simd<float16_t, N> x) {
float16x4_t y;
y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpmax_f16(y, y);
y = vpmax_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t min(Simd<float16_t, N> x) {
float16x4_t y;
y = vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpmin_f16(y, y);
y = vpmin_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t sum(Simd<float16_t, N> x) {
float16x4_t y;
y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value));
y = vpadd_f16(y, y);
y = vpadd_f16(y, y);
return vget_lane_f16(y, 0);
}
inline float16_t prod(Simd<float16_t, N> x) {
auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value));
auto out = hx[0];
hx[0] *= hx[1];
hx[0] *= hx[2];
hx[0] *= hx[3];
return hx[0];
}
} // namespace mlx::core::simd

View File

@@ -1,4 +0,0 @@
#pragma once
#include "mlx/backend/cpu/simd/math.h"
#include "mlx/backend/cpu/simd/type.h"

View File

@@ -1,7 +0,0 @@
#pragma once
#include "mlx/backend/cpu/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE
#include "mlx/backend/cpu/simd/accelerate_simd.h"
#endif

View File

@@ -1,21 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
std::tuple<int64_t, Strides> prepare_slice(
const array& in,
const Shape& start_indices,
const Shape& strides);
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out);
} // namespace mlx::core

View File

@@ -1,157 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
const T1* a,
const T2* b,
const T3* c,
U* out,
Op op,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& c_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
auto stride_c = c_strides[axis];
auto stride_out = out_strides[axis];
auto N = shape[axis];
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
a,
b,
c,
out,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
axis + 1);
} else {
*out = op(*a, *b, *c);
}
a += stride_a;
b += stride_b;
c += stride_c;
out += stride_out;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
auto [shape, strides] = collapse_contiguous_dims(
a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
const auto& a_strides = strides[0];
const auto& b_strides = strides[1];
const auto& c_strides = strides[2];
const auto& out_strides = strides[3];
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<T3>();
int ndim = shape.size();
switch (ndim) {
case 1:
ternary_op_dims<T1, T2, T3, U, Op, 1>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
case 2:
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr,
b_ptr,
c_ptr,
out_ptr,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
0);
return;
}
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
c_ptr + c_it.loc,
out_ptr + elem,
op,
shape,
a_strides,
b_strides,
c_strides,
out_strides,
ndim - 2);
a_it.step();
b_it.step();
c_it.step();
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
} else if (topt == TernaryOpType::VectorVectorVector) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* out_ptr = out.data<U>();
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;
c_ptr++;
out_ptr++;
}
} else {
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
}
} // namespace mlx::core

View File

@@ -1,300 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (issubdtype(in.dtype(), unsignedinteger) || in.dtype() == bool_) {
// No-op for unsigned types
out.copy_shared_buffer(in);
} else {
auto op = detail::Abs{};
switch (out.dtype()) {
case int8:
unary_op<int8_t>(in, out, op);
break;
case int16:
unary_op<int16_t>(in, out, op);
break;
case int32:
unary_op<int32_t>(in, out, op);
break;
case int64:
unary_op<int64_t>(in, out, op);
break;
case float16:
unary_op<float16_t>(in, out, op);
break;
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
case complex64:
unary_op<complex64_t>(in, out, op);
break;
default:
throw std::runtime_error("[Abs] Called on unsigned type");
}
}
}
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCos());
}
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcCosh());
}
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSin());
}
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcSinh());
}
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTan());
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::ArcTanh());
}
void BitwiseInvert::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_int(in, out, detail::BitwiseInvert());
}
void Ceil::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Ceil());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Conjugate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
unary_op<complex64_t>(inputs[0], out, detail::Conjugate());
}
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cos());
}
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Cosh());
}
void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::Erf());
break;
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
default:
throw std::invalid_argument(
"[erf] Error function only defined for arrays"
" with real floating point type.");
}
}
void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
unary_op<float>(in, out, detail::ErfInv());
break;
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;
default:
throw std::invalid_argument(
"[erf_inv] Inverse error function only defined for arrays"
" with real floating point type.");
}
}
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Exp());
}
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Expm1());
}
void Floor::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Floor());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (base_) {
case Base::e:
unary_fp(in, out, detail::Log());
break;
case Base::two:
unary_fp(in, out, detail::Log2());
break;
case Base::ten:
unary_fp(in, out, detail::Log10());
break;
}
}
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Log1p());
}
void LogicalNot::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::LogicalNot());
}
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Negative());
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Round::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_fp(in, out, detail::Round());
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sigmoid());
}
void Sign::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == bool_) {
out.copy_shared_buffer(in);
} else {
unary(in, out, detail::Sign());
}
}
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sin());
}
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Sinh());
}
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
unary(in, out, detail::Square());
}
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (recip_) {
unary_fp(in, out, detail::Rsqrt());
} else {
unary_fp(in, out, detail::Sqrt());
}
}
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tan());
}
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_fp(in, out, detail::Tanh());
}
} // namespace mlx::core

View File

@@ -1,109 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <stdint.h>
#include <cmath>
#include <complex>
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core::detail {
using namespace mlx::core::simd;
#define SINGLE() \
template <typename T> \
T operator()(T x) { \
return (*this)(Simd<T, 1>(x)).value; \
}
#define DEFAULT_OP(Op, op) \
struct Op { \
template <int N, typename T> \
Simd<T, N> operator()(Simd<T, N> x) { \
return simd::op(x); \
} \
SINGLE() \
};
DEFAULT_OP(Abs, abs)
DEFAULT_OP(ArcCos, acos)
DEFAULT_OP(ArcCosh, acosh)
DEFAULT_OP(ArcSin, asin)
DEFAULT_OP(ArcSinh, asinh)
DEFAULT_OP(ArcTan, atan)
DEFAULT_OP(ArcTanh, atanh)
DEFAULT_OP(BitwiseInvert, operator~)
DEFAULT_OP(Ceil, ceil)
DEFAULT_OP(Conjugate, conj)
DEFAULT_OP(Cos, cos)
DEFAULT_OP(Cosh, cosh)
DEFAULT_OP(Erf, erf)
DEFAULT_OP(ErfInv, erfinv)
DEFAULT_OP(Exp, exp)
DEFAULT_OP(Expm1, expm1)
DEFAULT_OP(Floor, floor);
DEFAULT_OP(Log, log);
DEFAULT_OP(Log2, log2);
DEFAULT_OP(Log10, log10);
DEFAULT_OP(Log1p, log1p);
DEFAULT_OP(LogicalNot, operator!)
DEFAULT_OP(Negative, operator-)
DEFAULT_OP(Round, rint);
DEFAULT_OP(Sin, sin)
DEFAULT_OP(Sinh, sinh)
DEFAULT_OP(Sqrt, sqrt)
DEFAULT_OP(Rsqrt, rsqrt)
DEFAULT_OP(Tan, tan)
DEFAULT_OP(Tanh, tanh)
struct Imag {
template <int N>
Simd<float, N> operator()(Simd<complex64_t, N> x) {
return simd::imag(x);
}
SINGLE()
};
struct Real {
template <int N>
Simd<float, N> operator()(Simd<complex64_t, N> x) {
return simd::real(x);
}
SINGLE()
};
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x));
}
SINGLE()
};
struct Sign {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
auto z = Simd<T, N>{0};
if constexpr (std::is_unsigned_v<T>) {
return x != z;
} else if constexpr (std::is_same_v<T, complex64_t>) {
return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
} else {
return simd::select(
x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
}
}
SINGLE()
};
struct Square {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return x * x;
}
SINGLE()
};
} // namespace mlx::core::detail

View File

@@ -12,7 +12,7 @@ function(make_jit_source SRC_FILE)
add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp
COMMAND
bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE}
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
@@ -35,8 +35,6 @@ make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(gather_axis)
make_jit_source(scatter_axis)
make_jit_source(hadamard)
if(MLX_METAL_JIT)
@@ -91,7 +89,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

View File

@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <format>
#include <sstream>
#include "mlx/backend/common/compiled.h"
@@ -11,8 +11,6 @@
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core {
inline void build_kernel(
@@ -41,7 +39,7 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os += fmt::format(
os += std::format(
"[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments
@@ -57,7 +55,7 @@ inline void build_kernel(
if (!is_scalar(x) && !contiguous) {
add_indices = true;
}
os += fmt::format(
os += std::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
@@ -65,13 +63,13 @@ inline void build_kernel(
}
if (add_indices) {
os += fmt::format(
os += std::format(
" constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
}
// Add the output arguments
for (auto& x : outputs) {
os += fmt::format(
os += std::format(
" device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
@@ -79,13 +77,13 @@ inline void build_kernel(
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os += fmt::format(
os += std::format(
" constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
os += fmt::format(
os += std::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
}
if (dynamic_dims) {
os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
os += std::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
}
// The thread index in the whole grid
@@ -98,15 +96,15 @@ inline void build_kernel(
// a third grid dimension
os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) {
os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
os += fmt::format(
os += std::format(" constexpr int N_ = {0};\n", work_per_thread);
os += std::format(
" int xshape = output_shape[{0}];\n",
dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
os += std::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} else {
os += fmt::format(
os += std::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
}
@@ -121,16 +119,16 @@ inline void build_kernel(
auto type_str = get_type_string(x.dtype());
std::ostringstream ss;
print_constant(ss, x);
os += fmt::format(
os += std::format(
" auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
} else if (is_scalar(x)) {
os += fmt::format(
os += std::format(
" {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
} else if (contiguous) {
os += fmt::format(
os += std::format(
" {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
} else {
nc_inputs.push_back(x);
@@ -140,30 +138,30 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
os += std::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) {
int offset = i * ndim;
os +=
fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
std::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) {
int offset = i * ndim;
os += fmt::format(
os += std::format(
"elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
} else if (ndim == 3) {
int offset = i * ndim;
os += fmt::format(
os += std::format(
"elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
} else if (!dynamic_dims) {
int offset = (i + 1) * ndim;
os += fmt::format(
os += std::format(
"N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
idx_type,
offset - 1,
offset - 2);
} else {
os += fmt::format(
os += std::format(
"N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
idx_type,
i);
@@ -175,18 +173,18 @@ inline void build_kernel(
if (dynamic_dims) {
os += " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
os += std::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
}
os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" index_{0} += ", xname);
os += std::format(" index_{0} += ", xname);
if (dynamic_dims) {
os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
std::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else {
os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
std::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
}
}
os += " zpos /= output_shape[d];\n }\n";
@@ -202,16 +200,16 @@ inline void build_kernel(
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os += fmt::format(
os += std::format(
" {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
}
// Actually write the computation
for (auto& x : tape) {
os += fmt::format(
os += std::format(
" {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
if (is_static_cast(x.primitive())) {
os += fmt::format(
os += std::format(
"static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
@@ -221,15 +219,15 @@ inline void build_kernel(
os += ss.str();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
os += std::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
}
os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
os += std::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
os += std::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
@@ -237,10 +235,10 @@ inline void build_kernel(
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os += fmt::format(
os += std::format(
" index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
} else {
os += fmt::format(
os += std::format(
" index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
}
}

View File

@@ -533,45 +533,6 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_fused_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
int O_c = conv_params.O;
int C_c = conv_params.C;
int N_tiles_n = conv_params.N;
int N_tiles_h = (conv_params.oS[0] + 1) / 2;
int N_tiles_w = (conv_params.oS[1] + 1) / 2;
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
int bc = 32;
int wm = 4;
int wn = 1;
std::ostringstream kname;
kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
<< conv_params.flip;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(conv_params, 3);
MTL::Size group_dims = MTL::Size(8, 8, 2);
MTL::Size grid_dims =
MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
const Stream& s,
metal::Device& d,
@@ -580,6 +541,67 @@ void winograd_conv_2D_gpu(
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
Shape padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
fill_gpu(zero_arr, in_padded, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
copies_w.push_back(in_padded_slice);
copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{
/* const int N = */ static_cast<int>(in_padded.shape(0)),
/* const int C = */ static_cast<int>(in_padded.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in_padded.shape(1)),
static_cast<int>(in_padded.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
/* const int idil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ false,
};
int O_c = conv_params.O;
int C_c = conv_params.C;
@@ -598,7 +620,7 @@ void winograd_conv_2D_gpu(
int bo = 4;
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc << "_flip" << conv_params.flip;
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
@@ -631,10 +653,10 @@ void winograd_conv_2D_gpu(
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder.set_bytes(conv_params, 2);
compute_encoder.set_bytes(conv_params_updated, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -681,7 +703,7 @@ void winograd_conv_2D_gpu(
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(conv_params, 2);
compute_encoder.set_bytes(conv_params_updated, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -745,18 +767,14 @@ void conv_2D_gpu(
}
// Direct to winograd conv
bool img_large =
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
is_kdil_one && is_idil_one) {
if (img_large && channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
if (conv_params.N <= 1) {
return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies);
}
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
@@ -858,40 +876,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy;
}
// Check for 1x1 conv
auto is_one = [](int x) { return x == 1; };
auto is_zero = [](int x) { return x == 0; };
if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) &&
std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) &&
std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) &&
std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) &&
std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) &&
std::all_of(padding_.begin(), padding_.end(), is_zero)) {
std::vector<array> empty_copies;
steel_matmul_regular(
s,
d,
/*a = */ in,
/*b = */ wt,
/*c = */ out,
/*M = */ in.size() / in.shape(-1),
/*N = */ wt.shape(0),
/*K = */ in.shape(-1),
/*batch_size_out = */ 1,
/*lda = */ in.shape(-1),
/*ldb = */ wt.shape(-1),
/*ldd = */ wt.shape(0),
/*transpose_a = */ false,
/*transpose_b = */ true,
/*batch_shape = */ {1},
/*batch_strides = */ {1},
/*A_batch_stride = */ 0,
/*B_batch_stride = */ 0,
/*matrix_stride_out = */ 0,
/*copies = */ empty_copies);
}
// 3D conv
else if (out.ndim() == 5) {
if (out.ndim() == 5) {
conv_3D_gpu(
s,
d,

View File

@@ -2,7 +2,6 @@
#include <sstream>
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"

View File

@@ -15,11 +15,10 @@ void CustomKernel::eval_gpu(
std::vector<array> copies;
for (auto& out : outputs) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
}

View File

@@ -13,7 +13,6 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace mlx::core::metal {
@@ -125,8 +124,8 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc_->retain();
}
@@ -135,20 +134,11 @@ CommandEncoder::~CommandEncoder() {
enc_->release();
}
void CommandEncoder::set_buffer(
const MTL::Buffer* buf,
int idx,
int64_t offset /* = 0 */) {
enc_->setBuffer(buf, offset, idx);
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
if (all_inputs_.insert(a.buffer().ptr()).second) {
stream_.buffer_sizes += a.data_size();
}
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
@@ -165,10 +155,6 @@ void CommandEncoder::set_output_array(
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
register_output_array(a);
}
void CommandEncoder::register_output_array(array& a) {
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent_) {
@@ -193,7 +179,6 @@ void CommandEncoder::dispatch_threadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
stream_.buffer_ops++;
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
@@ -201,44 +186,14 @@ void CommandEncoder::dispatch_threads(
MTL::Size grid_dims,
MTL::Size group_dims) {
maybeInsertBarrier();
stream_.buffer_ops++;
enc_->dispatchThreads(grid_dims, group_dims);
}
void CommandEncoder::barrier() {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
auto arch = arch_.back();
switch (arch) {
case 'p': // phone
max_ops_per_buffer_ = 20;
max_mb_per_buffer_ = 40;
break;
case 'g': // base, pro
max_ops_per_buffer_ = 40;
max_mb_per_buffer_ = 40;
break;
case 's': // max
max_ops_per_buffer_ = 50;
max_mb_per_buffer_ = 50;
break;
case 'd': // ultra
max_ops_per_buffer_ = 50;
max_mb_per_buffer_ = 50;
break;
default: // default to medium
max_ops_per_buffer_ = 40;
max_mb_per_buffer_ = 40;
break;
}
max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);
max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);
}
Device::~Device() {
@@ -269,13 +224,12 @@ void Device::new_queue(int index) {
}
}
bool Device::command_buffer_needs_commit(int index) {
auto& stream = get_stream_(index);
if (stream.buffer_ops > max_ops_per_buffer_ ||
(stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
return true;
}
return false;
int Device::get_command_buffer_ops(int index) {
return get_stream_(index).buffer_ops;
}
void Device::increment_command_buffer_ops(int index) {
get_stream_(index).buffer_ops++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
@@ -298,7 +252,6 @@ void Device::commit_command_buffer(int index) {
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
stream.buffer_sizes = 0;
}
void Device::add_temporary(array arr, int index) {
@@ -383,7 +336,7 @@ void Device::end_encoding(int index) {
CommandEncoder& Device::get_command_encoder(int index) {
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
stream.encoder = std::make_unique<CommandEncoder>(stream);
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
stream.fence = std::make_shared<Fence>(device_->newFence());
}
return *stream.encoder;

View File

@@ -38,10 +38,8 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct DeviceStream;
struct CommandEncoder {
explicit CommandEncoder(DeviceStream& stream);
CommandEncoder(MTL::CommandBuffer* cbuf);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -62,11 +60,9 @@ struct CommandEncoder {
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void register_output_array(array& a);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
@@ -114,10 +110,7 @@ struct CommandEncoder {
return all_outputs_;
};
void barrier();
private:
DeviceStream& stream_;
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
bool concurrent_{false};
@@ -150,10 +143,10 @@ struct DeviceStream {
// Used to allow thread-safe access to the outputs map
std::mutex fence_mtx;
// Data updated between command buffers
// The buffer and buffer op count are updated
// between command buffers
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
size_t buffer_sizes{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
@@ -179,7 +172,8 @@ class Device {
void new_queue(int index);
MTL::CommandBuffer* get_command_buffer(int index);
bool command_buffer_needs_commit(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);
void commit_command_buffer(int index);
CommandEncoder& get_command_encoder(int index);
void end_encoding(int index);
@@ -269,8 +263,6 @@ class Device {
std::unordered_map<std::string, MTL::Library*> library_map_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
int max_ops_per_buffer_;
int max_mb_per_buffer_;
};
Device& device(mlx::core::Device);

Some files were not shown because too many files have changed in this diff Show More