Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-12 23:34:36 +08:00)

Compare commits
41 commits
Commits in this comparison (SHA1):
1600092e92, dba2bd1105, 28be4de7c2, a6c3b38fba, fcb65a3897, 4e22a1dffe,
291cf40aca, bd47e1f066, e6b223df5f, e64349bbdd, cdb59faea6, 1d94ac3f90,
5f7d19d1f5, 2fdf9eb535, 860d3a50d7, d1183821a7, 8081df79be, 64bec4fad7,
b96e105244, 3b4d5484c7, 684e11c664, b57a52813b, da8deb2b62, 98b6ce3460,
f9e00efe31, 0fd2a1f4b0, df3233454d, 82db84b899, 8ae751d3da, d40e76809f,
bb1b76d9dc, 9d26441224, f12f24a77c, ae5b5cabfd, d0630ffe8c, 99bb7d3a58,
63ae767232, eaaea02010, a098bc92e0, 1086dc4db0, 19fb69e2ed
@@ -31,7 +31,7 @@ jobs:
      name: Install dependencies
      command: |
        pip install --upgrade cmake
        pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
        pip install nanobind==2.1.0
        pip install numpy
        sudo apt-get update
        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -44,6 +44,7 @@ jobs:
      name: Generate package stubs
      command: |
        echo "stubs"
        pip install typing_extensions
        python setup.py generate_stubs
  - run:
      name: Run Python tests
@@ -76,7 +77,7 @@ jobs:
        source env/bin/activate
        pip install --upgrade pip
        pip install --upgrade cmake
        pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
        pip install nanobind==2.1.0
        pip install numpy
        pip install torch
        pip install tensorflow
@@ -90,6 +91,7 @@ jobs:
      name: Generate package stubs
      command: |
        source env/bin/activate
        pip install typing_extensions
        python setup.py generate_stubs
  - run:
      name: Run Python tests
@@ -123,6 +125,12 @@ jobs:
        cd build/
        cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
        make -j
  - run:
      name: Run Python tests with JIT
      command: |
        source env/bin/activate
        CMAKE_BUILD_PARALLEL_LEVEL="" CMAKE_ARGS="-DMLX_METAL_JIT=ON" pip install -e . -v
        LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

  build_release:
    parameters:
@@ -149,7 +157,7 @@ jobs:
        source env/bin/activate
        pip install --upgrade pip
        pip install --upgrade cmake
        pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
        pip install nanobind==2.1.0
        pip install --upgrade setuptools
        pip install numpy
        pip install twine
@@ -165,6 +173,7 @@ jobs:
      name: Generate package stubs
      command: |
        source env/bin/activate
        pip install typing_extensions
        python setup.py generate_stubs
  - run:
      name: Build Python package
@@ -213,7 +222,7 @@ jobs:
        source env/bin/activate
        pip install --upgrade pip
        pip install --upgrade cmake
        pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
        pip install nanobind==2.1.0
        pip install --upgrade setuptools
        pip install numpy
        pip install auditwheel
@@ -222,6 +231,7 @@ jobs:
        << parameters.extra_env >> \
        CMAKE_BUILD_PARALLEL_LEVEL="" \
        pip install . -v
        pip install typing_extensions
        python setup.py generate_stubs
        << parameters.extra_env >> \
        CMAKE_BUILD_PARALLEL_LEVEL="" \

@@ -1,11 +1,11 @@
repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v18.1.4
    rev: v18.1.8
    hooks:
      - id: clang-format
  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
  - repo: https://github.com/psf/black-pre-commit-mirror
    rev: 24.4.2
    rev: 24.8.0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.16.2)
  set(MLX_VERSION 0.17.2)
endif()

# --------------------- Processor tests -------------------------

benchmarks/python/distributed_bench.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.

"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    if world.rank() == 0:
        if msg:
            print(f"Timing {msg} ...", end=" ")
        else:
            print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if world.rank() == 0:
        print(f"{msec:.5f} msec")


def time_all_sum():
    shape = (4096,)
    x = mx.random.uniform(shape=shape)
    mx.eval(x)

    def sine(x):
        for _ in range(20):
            x = mx.sin(x)
        return x

    time_fn(sine, x)

    def all_sum_plain(x):
        for _ in range(20):
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_plain, x)

    def all_sum_with_sine(x):
        for _ in range(20):
            x = mx.sin(x)
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_with_sine, x)


if __name__ == "__main__":
    time_all_sum()

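As a rough illustration of what the benchmark above measures, the following minimal sketch (not part of the commit; the shape and the rank-dependent fill are made up for illustration) sums one array across all ranks when launched under mpirun:

    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.ones((4096,)) * world.rank()   # each rank contributes its rank id
    y = mx.distributed.all_sum(x)         # every rank receives the elementwise sum
    mx.eval(y)
    if world.rank() == 0:
        print(y[0].item())                # with n ranks: 0 + 1 + ... + (n - 1)
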
docs/src/dev/custom_metal_kernels.rst (new file, 413 lines)
@@ -0,0 +1,413 @@
Custom Metal Kernels
====================

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example
--------------

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          T tmp = inp[elem];
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp",
          source=source,
      )
      outputs = kernel(
          inputs={"inp": a},
          template={"T": mx.float32},
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes={"out": a.shape},
          output_dtypes={"out": a.dtype},
      )
      return outputs["out"]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))

.. note::
   We are only required to pass the body of the Metal kernel in ``source``.

The full function signature will be generated using:

* The keys and shapes/dtypes of ``inputs``
  In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
  so we will add ``const device float16_t* inp`` to the signature.
  ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
  in ``source``.
* The keys and values of ``output_shapes`` and ``output_dtypes``
  In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
  so we add ``device float16_t* out``.
* Template parameters passed using ``template``
  In the above, ``template={"T": mx.float32}`` adds a template of ``template <typename T>`` to the function
  and instantiates the template with ``custom_kernel_myexp_float<float>``.
  Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
  These will be added as function arguments.
  All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.

Putting this all together, the generated function signature for ``myexp`` is as follows:

.. code-block:: cpp

  template <typename T>
  [[kernel]] void custom_kernel_myexp_float(
    const device float16_t* inp [[buffer(0)]],
    device float16_t* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]) {

          uint elem = thread_position_in_grid.x;
          T tmp = inp[elem];
          out[elem] = metal::exp(tmp);

  }

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.

Using Shape/Strides
-------------------

``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.

If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:

.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
          uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
          T tmp = inp[loc];
          // Output arrays are always row contiguous
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp_strided",
          source=source
      )
      outputs = kernel(
          inputs={"inp": a},
          template={"T": mx.float32},
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes={"out": a.shape},
          output_dtypes={"out": a.dtype},
          ensure_row_contiguous=False,
      )
      return outputs["out"]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  # make non-contiguous
  a = a[::2]
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))

Complex Example
---------------

Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.

We'll start with the following MLX implementation using standard ops:

.. code-block:: python

  def grid_sample_ref(x, grid):
      N, H_in, W_in, _ = x.shape
      ix = ((grid[..., 0] + 1) * W_in - 1) / 2
      iy = ((grid[..., 1] + 1) * H_in - 1) / 2

      ix_nw = mx.floor(ix).astype(mx.int32)
      iy_nw = mx.floor(iy).astype(mx.int32)

      ix_ne = ix_nw + 1
      iy_ne = iy_nw

      ix_sw = ix_nw
      iy_sw = iy_nw + 1

      ix_se = ix_nw + 1
      iy_se = iy_nw + 1

      nw = (ix_se - ix) * (iy_se - iy)
      ne = (ix - ix_sw) * (iy_sw - iy)
      sw = (ix_ne - ix) * (iy - iy_ne)
      se = (ix - ix_nw) * (iy - iy_nw)

      I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
      I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
      I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
      I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

      mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
      mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
      mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
      mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

      I_nw *= mask_nw[..., None]
      I_ne *= mask_ne[..., None]
      I_sw *= mask_sw[..., None]
      I_se *= mask_se[..., None]

      output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

      return output

Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

  @mx.custom_function
  def grid_sample(x, grid):

      assert x.ndim == 4, "`x` must be 4D."
      assert grid.ndim == 4, "`grid` must be 4D."

      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape
      out_shape = (B, gN, gM, C)

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          int gH = grid_shape[1];
          int gW = grid_shape[2];

          int w_stride = C;
          int h_stride = W * w_stride;
          int b_stride = H * h_stride;

          uint grid_idx = elem / C * 2;
          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

          int ix_nw = floor(ix);
          int iy_nw = floor(iy);

          int ix_ne = ix_nw + 1;
          int iy_ne = iy_nw;

          int ix_sw = ix_nw;
          int iy_sw = iy_nw + 1;

          int ix_se = ix_nw + 1;
          int iy_se = iy_nw + 1;

          T nw = (ix_se - ix) * (iy_se - iy);
          T ne = (ix - ix_sw) * (iy_sw - iy);
          T sw = (ix_ne - ix) * (iy - iy_ne);
          T se = (ix - ix_nw) * (iy - iy_nw);

          int batch_idx = elem / C / gH / gW * b_stride;
          int channel_idx = elem % C;
          int base_idx = batch_idx + channel_idx;

          T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
          T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
          T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
          T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

          I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
          I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
          I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
          I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

          out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
      """
      kernel = mx.fast.metal_kernel(
          name="grid_sample",
          source=source,
      )
      outputs = kernel(
          inputs={"x": x, "grid": grid},
          template={"T": x.dtype},
          output_shapes={"out": out_shape},
          output_dtypes={"out": x.dtype},
          grid=(np.prod(out_shape), 1, 1),
          threadgroup=(256, 1, 1),
      )
      return outputs["out"]

For a reasonably sized input such as:

.. code-block:: python

  x.shape = (8, 1024, 1024, 64)
  grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``

Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
  Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``
  Designate all of the kernel outputs as ``atomic`` in the function signature.
  This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
  See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.

We can then implement the backwards pass as follows:

.. code-block:: python

  @grid_sample.vjp
  def grid_sample_vjp(primals, cotangent, _):
      x, grid = primals
      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          // Pad C to the nearest larger simdgroup size multiple
          int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

          int gH = grid_shape[1];
          int gW = grid_shape[2];

          int w_stride = C;
          int h_stride = W * w_stride;
          int b_stride = H * h_stride;

          uint grid_idx = elem / C_padded * 2;
          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

          int ix_nw = floor(ix);
          int iy_nw = floor(iy);

          int ix_ne = ix_nw + 1;
          int iy_ne = iy_nw;

          int ix_sw = ix_nw;
          int iy_sw = iy_nw + 1;

          int ix_se = ix_nw + 1;
          int iy_se = iy_nw + 1;

          T nw = (ix_se - ix) * (iy_se - iy);
          T ne = (ix - ix_sw) * (iy_sw - iy);
          T sw = (ix_ne - ix) * (iy - iy_ne);
          T se = (ix - ix_nw) * (iy - iy_nw);

          int batch_idx = elem / C_padded / gH / gW * b_stride;
          int channel_idx = elem % C_padded;
          int base_idx = batch_idx + channel_idx;

          T gix = T(0);
          T giy = T(0);
          if (channel_idx < C) {
              int cot_index = elem / C_padded * C + channel_idx;
              T cot = cotangent[cot_index];
              if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                  int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                  T I_nw = x[offset];
                  gix -= I_nw * (iy_se - iy) * cot;
                  giy -= I_nw * (ix_se - ix) * cot;
              }
              if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                  int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                  T I_ne = x[offset];
                  gix += I_ne * (iy_sw - iy) * cot;
                  giy -= I_ne * (ix - ix_sw) * cot;
              }
              if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                  int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                  T I_sw = x[offset];
                  gix -= I_sw * (iy - iy_ne) * cot;
                  giy += I_sw * (ix_ne - ix) * cot;
              }
              if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                  int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                  T I_se = x[offset];
                  gix += I_se * (iy - iy_nw) * cot;
                  giy += I_se * (ix - ix_nw) * cot;
              }
          }

          T gix_mult = W / 2;
          T giy_mult = H / 2;

          // Reduce across each simdgroup first.
          // This is much faster than relying purely on atomics.
          gix = simd_sum(gix);
          giy = simd_sum(giy);

          if (thread_index_in_simdgroup == 0) {
              atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
              atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
          }
      """
      kernel = mx.fast.metal_kernel(
          name="grid_sample_grad",
          source=source,
          atomic_outputs=True,
      )
      # pad the output channels to simd group size
      # so that our `simd_sum`s don't overlap.
      simdgroup_size = 32
      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
      grid_size = B * gN * gM * C_padded
      outputs = kernel(
          inputs={"x": x, "grid": grid, "cotangent": cotangent},
          template={"T": x.dtype},
          output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
          output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
          grid=(grid_size, 1, 1),
          threadgroup=(256, 1, 1),
          init_value=0,
      )
      return outputs["x_grad"], outputs["grid_grad"]

There's an even larger speed up for the vjp:

``676.4ms -> 16.7ms => 40x speed up``

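One way to sanity-check a kernel like the one documented above is to compare it against the pure-ops reference. An illustrative snippet (assumes ``grid_sample`` and ``grid_sample_ref`` from the document are defined; the sizes are arbitrary):

    x = mx.random.normal(shape=(2, 32, 32, 8))
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))
    assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)
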
@@ -85,3 +85,4 @@ are the CPU and GPU.

   dev/extensions
   dev/metal_debugger
   dev/custom_metal_kernels

@@ -17,3 +17,6 @@ made available.

   init
   all_sum
   all_gather
   send
   recv
   recv_like

@@ -12,3 +12,5 @@ Fast

   layer_norm
   rope
   scaled_dot_product_attention
   affine_quantize
   metal_kernel

@@ -44,6 +44,7 @@ Operations

   convolve
   conv1d
   conv2d
   conv3d
   conv_general
   cos
   cosh

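The operations index above now lists ``conv3d``. A minimal illustrative call (a sketch; the NDHWC input layout and ``(C_out, kD, kH, kW, C_in)`` weight layout follow MLX's other conv ops, and the shapes here are arbitrary):

    import mlx.core as mx

    x = mx.random.normal(shape=(1, 8, 8, 8, 3))   # (N, D, H, W, C_in)
    w = mx.random.normal(shape=(4, 3, 3, 3, 3))   # (C_out, kD, kH, kW, C_in)
    y = mx.conv3d(x, w, stride=1, padding=0)
    print(y.shape)                                # expected (1, 6, 6, 6, 4)
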
@@ -2,7 +2,7 @@
requires = [
    "setuptools>=42",
    "cmake>=3.24",
    "mlx>=0.9.0",
    "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
    "mlx>=0.17.0",
    "nanobind==2.1.0",
]
build-backend = "setuptools.build_meta"

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.16.2
nanobind==2.0
mlx>=0.17.0
nanobind==2.1.0

@@ -13,7 +13,6 @@ if __name__ == "__main__":
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        extras_require={"dev": []},
        zip_safe=False,
        python_requires=">=3.8",
    )

@@ -70,7 +70,6 @@ inline float16x8_t neon_fast_exp(float16x8_t x) {

  x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);

@@ -2,6 +2,7 @@

#include <dlfcn.h>
#include <filesystem>
#include <fstream>
#include <list>

#include "mlx/backend/common/compiled.h"

@@ -5,11 +5,9 @@
#include <utility>

#include "mlx/allocator.h"
#include "mlx/io/load.h"
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <const uint8_t scalar_size>
@@ -29,12 +27,14 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {

} // namespace

void Load::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
namespace mlx::core {

  reader_->seek(offset_, std::ios_base::beg);
  reader_->read(out.data<char>(), out.nbytes());
void load(
    array& out,
    size_t offset,
    const std::shared_ptr<io::Reader>& reader,
    bool swap_endianness_) {
  reader->read(out.data<char>(), out.nbytes(), offset);

  if (swap_endianness_) {
    switch (out.itemsize()) {
@@ -51,4 +51,11 @@ void Load::eval(const std::vector<array>& inputs, array& out) {
  }
}

void Load::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  load(out, offset_, reader_, swap_endianness_);
}

} // namespace mlx::core

mlx/backend/common/load.h (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.

#include "mlx/array.h"
#include "mlx/io/load.h"

namespace mlx::core {

void load(
    array& out,
    size_t offset,
    const std::shared_ptr<io::Reader>& reader,
    bool swap_endianess);

} // namespace mlx::core

@@ -21,7 +21,7 @@ EOM

fi

CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)

cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

@@ -373,6 +373,10 @@ struct Sign {
  uint64_t operator()(uint64_t x) {
    return x != 0;
  }

  complex64_t operator()(complex64_t x) {
    return x == complex64_t(0) ? x : x / std::abs(x);
  }
};

struct Sin {

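For context, the new overload gives complex inputs the usual convention ``sign(x) = x / |x|`` (and leaves 0 at 0). An illustrative check from Python, not part of the diff:

    import mlx.core as mx

    z = mx.array([3 + 4j])   # complex64
    print(mx.sign(z))        # expected roughly (0.6 + 0.8j), i.e. z / |z|
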
@@ -87,6 +87,38 @@ struct OrReduce {
  }
};

struct MaxReduce {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
    (*y) = (*y > x) ? *y : x;
  };

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
    if (std::isnan(x)) {
      *y = x;
    } else {
      (*y) = (*y > x) ? *y : x;
    }
  };
};

struct MinReduce {
  template <typename T>
  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
    (*y) = (*y < x) ? *y : x;
  };

  template <typename T>
  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
    if (std::isnan(x)) {
      *y = x;
    } else {
      (*y) = (*y < x) ? *y : x;
    }
  };
};

template <typename InT>
void reduce_dispatch_out(
    const array& in,
@@ -118,15 +150,13 @@ void reduce_dispatch_out(
      break;
    }
    case Reduce::Max: {
      auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
      auto init = Limits<InT>::min;
      reduction_op<InT, InT>(in, out, axes, init, op);
      reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
      break;
    }
    case Reduce::Min: {
      auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
      auto init = Limits<InT>::max;
      reduction_op<InT, InT>(in, out, axes, init, op);
      reduction_op<InT, InT>(in, out, axes, init, MinReduce());
      break;
    }
  }

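The ``MaxReduce``/``MinReduce`` functors above make CPU max/min reductions propagate NaN instead of silently skipping it. A small illustrative check of the intended behavior (values made up):

    import mlx.core as mx

    x = mx.array([1.0, float("nan"), 3.0])
    print(mx.max(x))   # expected: nan -- a NaN input now poisons the reduction
    print(mx.min(x))   # expected: nan
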
@@ -49,7 +49,7 @@ struct ReductionPlan {
  ReductionPlan(ReductionOpType type_) : type(type_) {}
};

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);

// Helper for the ndimensional strided loop
// Should this be in utils?

@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
  return std::make_pair(shape, strides);
}

ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
  // The data is all there and we are reducing over everything
  if (x.size() == x.data_size() && axes.size() == x.ndim() &&
      x.flags().contiguous) {
@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
    }
  }

  // Remove singleton axes from the plan
  for (int i = shape.size() - 1; i >= 0; i--) {
    if (shape[i] == 1) {
      shape.erase(shape.begin() + i);
      strides.erase(strides.begin() + i);
    }
  }

  if (strides.back() == 1) {
    return ReductionPlan(ContiguousReduce, shape, strides);
  } else if (strides.back() > 1) {
@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // have a contiguous reduction.
  std::vector<std::pair<int, size_t>> reductions;
  for (auto a : axes) {
    reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
    if (x.shape(a) > 1) {
      reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
    }
  }
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    return a.second > b.second;
    bool a_is_zero = a.second == 0;
    bool b_is_zero = b.second == 0;
    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
  });
  // Extract the two smallest and try to merge them in case the contiguous
  // reduction can be bigger than just the last axis.
@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
  // strides.back() are contiguous.
  if (strides.back() > 1) {
    int size = 1;
    bool have_expand = false;
    for (int i = x.ndim() - 1; i >= 0; i--) {
      if (axes.back() == i) {
        continue;
      }
      if (x.strides()[i] != size) {

      size_t stride_i = x.strides()[i];
      int shape_i = x.shape(i);
      if (stride_i == 0) {
        if (shape_i == 1) {
          continue;
        }

        have_expand = true;
        break;
      }
      size *= x.shape(i);

      if (stride_i != size && shape_i != 1) {
        break;
      }
      size *= shape_i;
    }
    if (size >= strides.back()) {
    // In the case of an expanded dimension we are being conservative and
    // require the smallest reduction stride to be smaller than the maximum row
    // contiguous size. The reason is that we can't easily know if the reduced
    // axis is before or after an expanded dimension.
    if (size > strides.back() || (size == strides.back() && !have_expand)) {
      return ReductionPlan(GeneralStridedReduce, shape, strides);
    }
  }

@@ -12,6 +12,7 @@ namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
  ScalarScalarScalar,
  VectorVectorVector,
  General,
};

@@ -20,6 +21,12 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
  TernaryOpType topt;
  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
    topt = TernaryOpType::ScalarScalarScalar;
  } else if (
      (a.flags().row_contiguous && b.flags().row_contiguous &&
       c.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous &&
       c.flags().col_contiguous)) {
    topt = TernaryOpType::VectorVectorVector;
  } else {
    topt = TernaryOpType::General;
  }
@@ -33,11 +40,32 @@ void set_ternary_op_output_data(
    array& out,
    TernaryOpType topt,
    bool donate_with_move = false) {
  auto maybe_donate = [&out, donate_with_move](const array& x) {
    if (x.is_donatable() && x.itemsize() == out.itemsize()) {
      if (donate_with_move) {
        out.move_shared_buffer(x);
      } else {
        out.copy_shared_buffer(x);
      }
      return true;
    }
    return false;
  };

  switch (topt) {
    case TernaryOpType::ScalarScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
      break;
    case TernaryOpType::VectorVectorVector:
      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
        out.set_data(
            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
            b.data_size(),
            b.strides(),
            b.flags());
      }
      break;
    case TernaryOpType::General:
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
      break;

@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
      std::vector<array>{std::forward<Arrays>(xs)...});
}

// The single array version of the above.
inline std::tuple<std::vector<int>, std::vector<size_t>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::vector<int> collapsed_shape;
  std::vector<size_t> collapsed_strides;

  if (shape.size() > 0) {
    collapsed_shape.push_back(shape[0]);
    collapsed_strides.push_back(strides[0]);
    for (int i = 1; i < shape.size(); i++) {
      if (strides[i] * shape[i] != collapsed_strides.back() ||
          collapsed_shape.back() * static_cast<size_t>(shape[i]) >
              std::numeric_limits<int>::max()) {
        collapsed_shape.push_back(shape[i]);
        collapsed_strides.push_back(strides[i]);
      } else {
        collapsed_shape.back() *= shape[i];
        collapsed_strides.back() = strides[i];
      }
    }
  }

  return std::make_tuple(collapsed_shape, collapsed_strides);
}

template <typename stride_t>
inline auto check_contiguity(
    const std::vector<int>& shape,
@@ -115,8 +142,8 @@ inline auto check_contiguity(
  bool is_col_contiguous = true;

  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
    is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {

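For intuition, a rough Python paraphrase of the single-array ``collapse_contiguous_dims`` added above (illustrative only; the int-overflow guard from the C++ version is omitted):

    def collapse_dims(shape, strides):
        # merge a dim into its left neighbour when the neighbour's stride
        # equals this dim's stride times its extent (i.e. they are contiguous)
        if not shape:
            return [], []
        out_shape, out_strides = [shape[0]], [strides[0]]
        for n, s in zip(shape[1:], strides[1:]):
            if s * n == out_strides[-1]:
                out_shape[-1] *= n
                out_strides[-1] = s
            else:
                out_shape.append(n)
                out_strides.append(s)
        return out_shape, out_strides

    print(collapse_dims([4, 5, 6], [30, 6, 1]))   # ([120], [1]): fully contiguous collapses to one dim
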
@@ -79,6 +79,7 @@ if (MLX_METAL_JIT)
  kernels/reduction/reduce_all.h
  kernels/reduction/reduce_col.h
  kernels/reduction/reduce_row.h
  kernels/reduction/reduce_init.h
)
make_jit_source(
  steel/gemm/gemm
@@ -131,6 +132,8 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

mlx/backend/metal/custom_kernel.cpp (new file, 89 lines)
@@ -0,0 +1,89 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"

namespace mlx::core::fast {

void CustomKernel::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& s = stream();

  std::vector<array> copies;

  for (auto& out : outputs) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    if (init_value_) {
      array init = array(init_value_.value(), out.dtype());
      copy_gpu(init, out, CopyType::Scalar, s);
      copies.push_back(init);
    }
  }

  auto check_input = [&copies, &s, this](const array& x) -> const array {
    bool no_copy = x.flags().row_contiguous;
    if (!ensure_row_contiguous_ || no_copy) {
      return x;
    } else {
      copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
      copy_gpu(x, copies.back(), CopyType::General, s);
      return copies.back();
    }
  };
  std::vector<const array> checked_inputs;
  for (const array& in : inputs) {
    checked_inputs.push_back(check_input(in));
  }

  auto& d = metal::device(s.device);
  const auto& lib_name = name_;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    lib = d.get_library(lib_name, metal::utils() + source_);
  }
  auto kernel = d.get_kernel(name_, lib);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  int index = 0;
  for (int i = 0; i < checked_inputs.size(); i++) {
    const array& in = checked_inputs[i];
    auto shape_info = shape_infos_[i];
    compute_encoder.set_input_array(in, index);
    index++;
    if (in.ndim() > 0) {
      int ndim = in.ndim();
      if (shape_info.shape) {
        set_vector_bytes(compute_encoder, in.shape(), ndim, index);
        index++;
      }
      if (shape_info.strides) {
        set_vector_bytes(compute_encoder, in.strides(), ndim, index);
        index++;
      }
      if (shape_info.ndim) {
        compute_encoder->setBytes(&ndim, sizeof(int), index);
        index++;
      }
    }
  }
  for (array out : outputs) {
    compute_encoder.set_output_array(out, index);
    index++;
  }

  const auto [tx, ty, tz] = threadgroup_;
  MTL::Size group_dims = MTL::Size(tx, ty, tz);
  const auto [gx, gy, gz] = grid_;
  MTL::Size grid_dims = MTL::Size(gx, gy, gz);
  compute_encoder->dispatchThreads(grid_dims, group_dims);

  if (!copies.empty()) {
    d.get_command_buffer(s.index)->addCompletedHandler(
        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
  }
}

} // namespace mlx::core::fast

@@ -1,8 +1,6 @@
// Copyright © 2023-2024 Apple Inc.

#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>

#include <sys/sysctl.h>
@@ -16,8 +14,6 @@
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/utils.h"

namespace fs = std::filesystem;

namespace mlx::core::metal {

namespace {
@@ -38,20 +34,6 @@ constexpr auto get_metal_version() {
#endif
}

std::string get_colocated_mtllib_path(const std::string& lib_name) {
  Dl_info info;
  std::string mtllib_path;
  std::string lib_ext = lib_name + ".metallib";

  int success = dladdr((void*)get_colocated_mtllib_path, &info);
  if (success) {
    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
    mtllib_path = mtllib.c_str();
  }

  return mtllib_path;
}

auto load_device() {
  auto devices = MTL::CopyAllDevices();
  auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -311,12 +293,6 @@ void Device::register_library(
  }
}

void Device::register_library(const std::string& lib_name) {
  if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
    register_library(lib_name, get_colocated_mtllib_path(lib_name));
  }
}

MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
  // Search for cached metal lib
  MTL::Library* mtl_lib;

@@ -3,6 +3,8 @@
#pragma once

#include <Metal/Metal.hpp>
#include <dlfcn.h>
#include <filesystem>
#include <functional>
#include <mutex>
#include <string>
@@ -12,8 +14,26 @@
#include "mlx/array.h"
#include "mlx/device.h"

namespace fs = std::filesystem;

namespace mlx::core::metal {

// Note, this function must be left inline in a header so that it is not
// dynamically linked.
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
  Dl_info info;
  std::string mtllib_path;
  std::string lib_ext = lib_name + ".metallib";

  int success = dladdr((void*)get_colocated_mtllib_path, &info);
  if (success) {
    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
    mtllib_path = mtllib.c_str();
  }

  return mtllib_path;
}

using MTLFCList =
    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

@@ -86,7 +106,13 @@ class Device {
      const std::string& lib_name,
      const std::string& lib_path);

  void register_library(const std::string& lib_name);
  // Note, this should remain in the header so that it is not dynamically
  // linked
  void register_library(const std::string& lib_name) {
    if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
      register_library(lib_name, get_colocated_mtllib_path(lib_name));
    }
  }

  MTL::Library* get_library(const std::string& name);

mlx/backend/metal/distributed.cpp (new file, 142 lines)
@@ -0,0 +1,142 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h"

namespace mlx::core::distributed {

void signal_and_wait(const array& in, const array& out, const Stream& s) {
  auto& d = metal::device(s.device);
  d.end_encoding(s.index);
  auto command_buffer = d.get_command_buffer(s.index);
  if (in.event().valid()) {
    command_buffer->encodeSignalEvent(
        static_cast<MTL::Event*>(in.event().raw_event().get()),
        in.event().value());
  }
  command_buffer->encodeWait(
      static_cast<MTL::Event*>(out.event().raw_event().get()),
      out.event().value());
}

void AllReduce::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& in = inputs[0];
  auto& out = outputs[0];
  if (in.is_donatable()) {
    out.move_shared_buffer(in);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }

  auto task = [in = in,
               out = out,
               reduce_type = reduce_type_,
               group = group()]() mutable {
    if (in.event().valid()) {
      in.event().wait();
    }
    switch (reduce_type) {
      case Sum:
        distributed::detail::all_sum(
            group, in.data_shared_ptr() == nullptr ? out : in, out);
        break;
      default:
        throw std::runtime_error("Only all reduce sum is supported for now");
    }
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));

  signal_and_wait(in, out, stream());
}

void AllGather::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  auto& in = inputs[0];
  auto& out = outputs[0];

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto task = [in = in, out = out, group = group()]() mutable {
    if (in.event().valid()) {
      in.event().wait();
    }
    distributed::detail::all_gather(group, in, out);
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));
  signal_and_wait(in, out, stream());
}

void Send::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& in = inputs[0];
  auto& out = outputs[0];

  // Schedule an async send on the comm stream
  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
    if (in.event().valid()) {
      in.event().wait();
    }
    distributed::detail::send(group, in, dst);
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));

  // Encode a signal event for the input but not a wait since we don't need to
  // wait on the output.
  auto& s = stream();
  auto& d = metal::device(s.device);
  d.end_encoding(s.index);
  auto command_buffer = d.get_command_buffer(s.index);
  if (in.event().valid()) {
    command_buffer->encodeSignalEvent(
        static_cast<MTL::Event*>(in.event().raw_event().get()),
        in.event().value());
  }
}

void Recv::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 0);
  assert(outputs.size() == 1);

  auto& out = outputs[0];

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Schedule an async recv on the comm stream
  auto task = [out = out, group = group(), src = src_]() mutable {
    distributed::detail::recv(group, out, src);
    out.event().signal();
  };
  scheduler::enqueue(detail::communication_stream(), std::move(task));

  // Encode a wait event as there is no input for the recv to encode a signal.
  auto& s = stream();
  auto& d = metal::device(s.device);
  auto command_buffer = d.get_command_buffer(s.index);
  command_buffer->encodeWait(
      static_cast<MTL::Event*>(out.event().raw_event().get()),
      out.event().value());
}

} // namespace mlx::core::distributed

@@ -546,8 +546,8 @@ void fft_op(
  auto [data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(x.shape(), strides);

  flags.col_contiguous = is_row_contiguous;
  flags.row_contiguous = is_col_contiguous;
  flags.col_contiguous = is_col_contiguous;
  flags.row_contiguous = is_row_contiguous;
  flags.contiguous = data_size == x_copy.size();

  x_copy.set_data(

@@ -95,11 +95,21 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
    slice_size *= s;
  }

  // Launch 2D grid of threads: indices x slice
  size_t dim0 = out.size() / slice_size;
  size_t dim1 = slice_size;
  auto group_dims = get_block_dims(dim0, dim1, 1);
  MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
  // Launch 3D grid of threads
  // First two dimensions for the indices, the last one for the slice
  size_t dim0 = 1;
  size_t dim1 = 1;
  if (nidx) {
    if (inputs[1].ndim() >= 1) {
      dim0 = inputs[1].shape(0);
    }
    if (inputs[1].ndim() >= 2) {
      dim1 = inputs[1].size() / dim0;
    }
  }
  size_t dim2 = slice_size;
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);

  // Collect all idx shapes and strides into one place
  std::vector<int> idx_shapes;

@@ -13,8 +13,8 @@ constexpr std::string_view gather_kernels = R"(
    const constant size_t* idx_strides [[buffer(8)]],
    const constant int& idx_ndim [[buffer(9)]],
    {4}
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {{
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {{
  Indices<{2}, {3}> idxs{{
      {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};

@@ -1,168 +0,0 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
    device {1}* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {{
  out[tid] = {2}<{1}>::init;
}}
)";

constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup {2}* local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device mlx_atomic<{2}>* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint thread_group_id [[threadgroup_position_in_grid]]);

template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    threadgroup {2}* local_data [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant size_t& out_size [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    const constant size_t& non_col_reductions [[buffer(8)]],
    const constant int* non_col_shapes [[buffer(9)]],
    const constant size_t* non_col_strides [[buffer(10)]],
    const constant int& non_col_ndim [[buffer(11)]],
    uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
    const device {1}* in [[buffer(0)]],
    device {2}* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& out_size [[buffer(3)]],
    const constant size_t& non_row_reductions [[buffer(4)]],
    const constant int* shape [[buffer(5)]],
    const constant size_t* strides [[buffer(6)]],
    const constant int& ndim [[buffer(7)]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint3 gsize [[threads_per_grid]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

@@ -6,7 +6,6 @@
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
@@ -323,12 +322,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -347,14 +347,22 @@ MTL::ComputePipelineState* get_reduce_kernel(
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
std::vector<std::pair<std::string, std::string>> reduce_kernels = {
{"all_reduce", "allReduce"},
{"col_reduce_small", "colReduceSmall"},
{"col_reduce_looped", "colReduceLooped"},
{"row_reduce_small", "rowReduceSmall"},
{"row_reduce_looped", "rowReduceLooped"},
{"row_reduce_simple", "rowReduceSimple"}};
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
for (auto [func, name] : reduce_kernels) {
kernel_source << get_template_definition(
name + "_" + lib_name, func, in_type, out_type, op);
}

lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -37,13 +37,13 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}

@@ -51,13 +51,15 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
METAL_FUNC void mlx_atomic_fetch_or_explicit(
device mlx_atomic<T>* object,
T val,
size_t offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}

@@ -65,7 +67,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}

@@ -73,7 +75,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}

@@ -81,7 +83,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}

@@ -89,7 +91,7 @@ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
@@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
uint offset) {
size_t offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
@@ -115,7 +117,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
size_t offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -130,7 +132,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
size_t offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -157,7 +159,7 @@ union uint_or_packed {

template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
@@ -168,9 +170,9 @@ template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;

mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
@@ -251,9 +253,9 @@ struct __Min {

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
uint pack_offset = offset / sizeof(T);
uint elem_offset = offset % sizeof(T);
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
size_t pack_offset = offset / sizeof(T);
size_t elem_offset = offset % sizeof(T);
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
@@ -262,7 +264,7 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}

@@ -270,9 +272,9 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
@@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
METAL_FUNC void mlx_atomic_fetch_or_explicit(
device mlx_atomic<T>* object,
T val,
size_t offset) {
size_t pack_offset = offset / packing_size<T>;
size_t elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
@@ -298,7 +302,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}

@@ -306,7 +310,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}

@@ -314,7 +318,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}

@@ -322,7 +326,7 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
size_t offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}

@@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
uint offset) {
size_t offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
@@ -23,6 +23,8 @@ struct complex64_t {

// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
constexpr complex64_t() : real(0), imag(0) {};
constexpr complex64_t() threadgroup : real(0), imag(0) {};

// Conversions to complex64_t
template <
@@ -9,7 +9,8 @@
#endif

static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
@@ -14,32 +14,36 @@ METAL_FUNC void gather_impl(
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;

uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
idx_loc = index.x * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}

auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);

size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
size_t out_idx = index.z;
if (IDX_NDIM == 1) {
out_idx += static_cast<size_t>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) {
out_idx +=
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
}
out[out_idx] = src[src_offset + src_idx];
}
@@ -1,4 +1,5 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"
@@ -8,7 +8,6 @@
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
@@ -28,7 +27,8 @@

#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
@@ -83,9 +83,9 @@
op)

#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i_reduce_" #name)]] [[kernel]] void \
init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)

#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
@@ -97,40 +97,15 @@ instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)

#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
instantiate_kernel("allReduce_" #name, \
all_reduce, \
itype, otype, op)

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
@@ -138,153 +113,72 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim)

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)

#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)

#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 3) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 3) \
instantiate_col_reduce_looped(name, itype, otype, op, 4)

#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_general(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)

#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)

#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("rowReduceSimple_" #name, \
row_reduce_simple, \
itype, otype, op)

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
@@ -5,6 +5,20 @@
#include <metal_atomic>
#include <metal_simdgroup>

#define DEFINE_SIMD_REDUCE() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_reduce(T val) { \
return simd_reduce_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_reduce(T val) { \
for (short i = simd_size / 2; i > 0; i /= 2) { \
val = operator()(val, simd_shuffle_down(val, i)); \
} \
return val; \
}

static constant constexpr const uint8_t simd_size = 32;

union bool4_or_uint {
@@ -14,14 +28,16 @@ union bool4_or_uint {

struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};

template <typename U = bool>
struct And {
bool simd_reduce(bool val) {
DEFINE_SIMD_REDUCE()

bool simd_reduce_impl(bool val) {
return simd_all(val);
}

@@ -31,7 +47,7 @@ struct And {
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
size_t offset = 0) {
if (!val) {
bool4_or_uint update;
update.b = {true, true, true, true};
@@ -40,7 +56,8 @@ struct And {
}
}

void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -59,7 +76,9 @@ struct And {

template <typename U = bool>
struct Or {
bool simd_reduce(bool val) {
DEFINE_SIMD_REDUCE()

bool simd_reduce_impl(bool val) {
return simd_any(val);
}

@@ -68,8 +87,8 @@ struct Or {
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
uint elem_idx,
uint offset = 0) {
int elem_idx,
size_t offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
@@ -78,7 +97,8 @@ struct Or {
}
}

void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void
atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -97,15 +117,17 @@ struct Or {

template <typename U>
struct Sum {
DEFINE_SIMD_REDUCE()

template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_sum(val);
}

static constexpr constant U init = U(0);

template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}

@@ -117,15 +139,17 @@ struct Sum {

template <typename U>
struct Prod {
DEFINE_SIMD_REDUCE()

template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_product(val);
}

static constexpr constant U init = U(1);

template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}

@@ -137,15 +161,17 @@ struct Prod {

template <typename U>
struct Min {
DEFINE_SIMD_REDUCE()

template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_min(val);
}

static constexpr constant U init = Limits<U>::max;

template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}

@@ -157,15 +183,17 @@ struct Min {

template <typename U>
struct Max {
DEFINE_SIMD_REDUCE()

template <typename T>
T simd_reduce(T val) {
T simd_reduce_impl(T val) {
return simd_max(val);
}

static constexpr constant U init = Limits<U>::min;

template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}

@@ -1,135 +1,61 @@
// Copyright © 2023-2024 Apple Inc.

///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
const device T* in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;

if (gid * N_READS < in_size) {
in += gid * N_READS;

int r = 0;
for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};

for (int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for (int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}

in += grid_size * N_READS;
}

// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];

for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for (int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}

return total_val;
}

///////////////////////////////////////////////////////////////////////////////
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////

// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
device U* out [[buffer(1)]],
const constant size_t& in_size [[buffer(2)]],
const constant size_t& row_size [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
threadgroup U shared_vals[simd_size];

U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total = Op::init;
int64_t start_idx = gid.y * row_size;
int64_t actual_row =
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
int64_t blocks = actual_row / (lsize.x * N_READS);
int extra = actual_row - blocks * (lsize.x * N_READS);
extra -= lid.x * N_READS;
start_idx += lid.x * N_READS;
in += start_idx;

if (extra >= N_READS) {
blocks++;
extra = 0;
}

for (int64_t b = 0; b < blocks; b++) {
for (int i = 0; i < N_READS; i++) {
total = op(static_cast<U>(in[i]), total);
}
in += lsize.x * N_READS;
}
if (extra > 0) {
for (int i = 0; i < extra; i++) {
total = op(static_cast<U>(in[i]), total);
}
}

// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
total = op.simd_reduce(total);
if (simd_per_group > 1) {
if (simd_lane_id == 0) {
shared_vals[simd_group_id] = total;
}

// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init;
total = op.simd_reduce(total);
}

// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);

// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];

U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);

// Reduction within simd group (simd_add isn't supported for uint64/int64
// types)
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}

// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
if (lid.x == 0) {
out[gid.y] = total;
}
}
@@ -1,165 +1,337 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Small column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
const constant size_t& non_col_reductions [[buffer(8)]],
|
||||
const constant int* non_col_shapes [[buffer(9)]],
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
// Appease the compiler
|
||||
(void)out_size;
|
||||
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 tsize [[threads_per_grid]]) {
|
||||
Op op;
|
||||
U total_val = Op::init;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
auto out_idx = tid;
|
||||
// Case 1: Small row small column
|
||||
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
|
||||
U totals[31];
|
||||
for (int i = 0; i < 31; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
in += elem_to_loc(
|
||||
out_idx,
|
||||
shape + non_col_ndim,
|
||||
strides + non_col_ndim,
|
||||
ndim - non_col_ndim);
|
||||
short stride = reduction_stride;
|
||||
short size = reduction_size;
|
||||
short blocks = stride / N_READS;
|
||||
short extra = stride - blocks * N_READS;
|
||||
|
||||
for (uint i = 0; i < non_col_reductions; i++) {
|
||||
size_t in_idx =
|
||||
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
|
||||
U val = static_cast<U>(in[in_idx]);
|
||||
total_val = op(total_val, val);
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < blocks; j++) {
|
||||
for (short k = 0; k < N_READS; k++) {
|
||||
totals[j * N_READS + k] =
|
||||
op(totals[j * N_READS + k],
|
||||
static_cast<U>(row[i * stride + j * N_READS + k]));
|
||||
}
|
||||
}
|
||||
for (short k = 0; k < extra; k++) {
|
||||
totals[blocks * N_READS + k] =
|
||||
op(totals[blocks * N_READS + k],
|
||||
static_cast<U>(row[i * stride + blocks * N_READS + k]));
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride;
|
||||
for (short j = 0; j < stride; j++) {
|
||||
out[j] = totals[j];
|
||||
}
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
// Case 2: Long row small column
|
||||
else if (reduction_size * non_col_reductions < 32) {
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce helper
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
short size = reduction_size;
|
||||
size_t offset = size_t(tid.x) * N_READS;
|
||||
bool safe = offset + N_READS <= reduction_stride;
|
||||
short extra = reduction_stride - offset;
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
METAL_FUNC U _contiguous_strided_reduce(
|
||||
const device T* in,
|
||||
threadgroup U* local_data,
|
||||
uint in_idx,
|
||||
uint reduction_size,
|
||||
uint reduction_stride,
|
||||
uint2 tid,
|
||||
uint2 lid,
|
||||
uint2 lsize) {
|
||||
Op op;
|
||||
U total_val = Op::init;
|
||||
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
|
||||
|
||||
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
|
||||
for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
|
||||
uint offset = base_offset + r;
|
||||
total_val =
|
||||
op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
|
||||
}
|
||||
local_data[lsize.y * lid.x + lid.y] = total_val;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint r = 0; r < non_col_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
U val = Op::init;
|
||||
if (lid.y == 0) {
|
||||
// Perform reduction across columns in thread group
|
||||
for (uint i = 0; i < lsize.y; i++) {
|
||||
val = op(val, local_data[lsize.y * lid.x + i]);
|
||||
if (safe) {
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < N_READS; j++) {
|
||||
totals[j] =
|
||||
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < size; i++) {
|
||||
for (short j = 0; j < extra; j++) {
|
||||
totals[j] =
|
||||
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
out += out_idx * reduction_stride + offset;
|
||||
if (safe) {
|
||||
for (short i = 0; i < N_READS; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (short i = 0; i < extra; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
// Case 3: Long row medium column
|
||||
else {
|
||||
threadgroup U shared_vals[1024];
|
||||
U totals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce kernel
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
short stride = reduction_stride;
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 tile((stride + N_READS - 1) / N_READS, 32);
|
||||
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
|
||||
short sm_stride = tile.x * N_READS;
|
||||
bool safe = offset.x + N_READS <= stride;
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general(
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
|
||||
|
||||
Op op;
|
||||
if (out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
// Read cooperatively and contiguously and aggregate the partial results.
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific
|
||||
// output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
op.atomic_update(out, val, out_idx);
|
||||
if (safe) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// Each thread holds N_READS partial results but the simdgroups are not
|
||||
// aligned to do the reduction across the simdgroup so we write our results
|
||||
// in the shared memory and read them back according to the simdgroup.
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[i] = op.simd_reduce(
|
||||
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
short column = simd_group_id * N_READS;
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (column + N_READS <= stride) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; column + i < stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general_no_atomics(
|
||||
/**
|
||||
* Our approach is the following simple looped approach:
|
||||
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
|
||||
* 2. Load a tile BM, BN in registers and accumulate in the running totals
|
||||
* 3. Move ahead by BM steps until the column axis and the non column
|
||||
* reductions are exhausted.
|
||||
* 6. If BM == 32 then transpose in SM and simd reduce the running totals.
|
||||
* Otherwise write in shared memory and BN threads accumulate the running
|
||||
* totals with a loop.
|
||||
* 7. Write them to the output
|
||||
*/
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int BM = 8,
|
||||
int BN = 128>
|
||||
[[kernel]] void col_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
const constant size_t& non_col_reductions [[buffer(10)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
constexpr int n_simdgroups = 4;
|
||||
constexpr short tgp_size = n_simdgroups * simd_size;
|
||||
constexpr short n_reads = (BM * BN) / tgp_size;
|
||||
constexpr short n_read_blocks = BN / n_reads;
|
||||
|
||||
if (out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
threadgroup U shared_vals[BN * BM];
|
||||
U totals[n_reads];
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific
|
||||
// output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
uint tgsize_y = ceildiv(gsize.y, lsize.y);
|
||||
uint tgsize_z = ceildiv(gsize.z, lsize.z);
|
||||
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
short lid = simd_group_id * simd_size + simd_lane_id;
|
||||
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
|
||||
size_t column = BN * gid.x + offset.x;
|
||||
bool safe = column + n_reads <= reduction_stride;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
in += in_idx + column;
|
||||
|
||||
size_t total = non_col_reductions * reduction_size;
|
||||
loop.next(offset.y, reduce_shape, reduce_strides);
|
||||
for (size_t r = offset.y; r < total; r += BM) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(static_cast<U>(row[i]), totals[i]);
|
||||
}
|
||||
} else {
|
||||
U vals[n_reads];
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
vals[i] =
|
||||
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
|
||||
}
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
totals[i] = op(vals[i], totals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(BM, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// We can use a simd reduction to accumulate across BM so each thread writes
|
||||
// the partial output to SM and then each simdgroup does BN / n_simdgroups
|
||||
// accumulations.
|
||||
if (BM == 32) {
|
||||
constexpr int n_outputs = BN / n_simdgroups;
|
||||
static_assert(
|
||||
BM != 32 || n_outputs == n_reads,
|
||||
"The tile should be selected such that n_outputs == n_reads");
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[offset.y * BN + offset.x + i] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
totals[i] =
|
||||
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (simd_lane_id == 0) {
|
||||
size_t out_column = BN * gid.x + out_offset.x;
|
||||
out += out_idx * reduction_stride + out_column;
|
||||
if (out_column + n_outputs <= reduction_stride) {
|
||||
for (int i = 0; i < n_outputs; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; out_column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Each thread holds n_reads partial results. We write them all out to shared
|
||||
// memory and threads with offset.y == 0 aggregate the columns and write the
|
||||
// outputs.
|
||||
else {
|
||||
short x_block = offset.x / n_reads;
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (offset.y == 0) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
for (int j = 1; j < BM; j++) {
|
||||
totals[i] =
|
||||
op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the output.
|
||||
if (offset.y == 0) {
|
||||
out += out_idx * reduction_stride + column;
|
||||
if (safe) {
|
||||
for (int i = 0; i < n_reads; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; column + i < reduction_stride; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
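// Illustrative host-side reference (a sketch, not part of the kernel source
// above): the strided column reduction that col_reduce_looped computes for a
// plain 2D row-major input, written as two nested loops. A sum reduction and
// row-major layout are assumptions for the example only.
#include <cstddef>
#include <vector>

template <typename T, typename U>
std::vector<U> column_reduce_reference(
    const std::vector<T>& in,
    size_t reduction_size, // rows folded into each output
    size_t reduction_stride) { // row length == number of outputs
  std::vector<U> out(reduction_stride, U(0));
  for (size_t r = 0; r < reduction_size; r++) {
    for (size_t c = 0; c < reduction_stride; c++) {
      out[c] += static_cast<U>(in[r * reduction_stride + c]);
    }
  }
  return out;
}
// The kernel tiles these loops: each threadgroup owns a BN-wide slice of
// columns, each thread accumulates n_reads of those columns over rows spaced
// BM apart, and the BM partial rows are folded either with a simd reduction
// (BM == 32) or with a shared-memory loop before writing the outputs.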
|
||||
|
@@ -1,287 +1,366 @@
// Copyright © 2023-2024 Apple Inc.

///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////

// Row reduction utilities
// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
//   lid.x == 0 holds the reduced value
// - `thread_reduce` simple loop and reduce the row

// Each thread reduces for one output
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void row_reduce_general_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]) {
|
||||
/**
 * The thread group collaboratively reduces across the rows with bounds
 * checking. In the end each thread holds a part of the reduction.
 */
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* inputs[N_WRITES],
|
||||
int blocks,
|
||||
int extra,
|
||||
uint lsize_x,
|
||||
uint lid_x) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = lid;
|
||||
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
// Set up the accumulator registers
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
totals[i] = Op::init;
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
// Loop over the reduction size within thread group
|
||||
for (int i = 0; i < blocks; i++) {
|
||||
for (int j = 0; j < N_WRITES; j++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
|
||||
}
|
||||
|
||||
for (short r = 0; r < short(non_row_reductions); r++) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = 0; i < short(reduction_size); i++) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
inputs[j] += lsize_x * N_READS;
|
||||
}
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
|
||||
// Each simdgroup reduces for one output
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void row_reduce_general_med(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = simd_per_group * tid + simd_group_id;
|
||||
|
||||
if (out_idx >= out_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
|
||||
if (short(non_row_reductions) == 1) {
|
||||
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
else if (short(non_row_reductions) >= 32) {
|
||||
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = 0; i < short(reduction_size); i++) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
// Separate case for the last set as we close the reduction size
|
||||
int index = lid_x * N_READS;
|
||||
if (index + N_READS <= extra) {
|
||||
for (int j = 0; j < N_WRITES; j++) {
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
else {
|
||||
const short n_reductions =
|
||||
short(reduction_size) * short(non_row_reductions);
|
||||
const short reductions_per_thread =
|
||||
(n_reductions + simd_size - 1) / simd_size;
|
||||
|
||||
const short r_st = simd_lane_id / reductions_per_thread;
|
||||
const short r_ed = short(non_row_reductions);
|
||||
const short r_jump = simd_size / reductions_per_thread;
|
||||
|
||||
const short i_st = simd_lane_id % reductions_per_thread;
|
||||
const short i_ed = short(reduction_size);
|
||||
const short i_jump = reductions_per_thread;
|
||||
|
||||
if (r_st < r_jump) {
|
||||
for (short r = r_st; r < r_ed; r += r_jump) {
|
||||
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
|
||||
const device T* in_row = in + in_idx;
|
||||
|
||||
for (short i = i_st; i < i_ed; i += i_jump) {
|
||||
total_val = op(static_cast<U>(in_row[i]), total_val);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < N_WRITES; j++) {
|
||||
for (int i = 0; index + i < extra; i++) {
|
||||
totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
if (simd_lane_id == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
METAL_FUNC U per_thread_row_reduce(
|
||||
/**
 * Consecutive rows in a contiguous array.
 */
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* in,
|
||||
const constant size_t& reduction_size,
|
||||
const constant size_t& out_size,
|
||||
int blocks,
|
||||
int extra,
|
||||
uint lsize_x,
|
||||
uint lid_x) {
|
||||
// Set up the input pointers
|
||||
const device T* inputs[N_WRITES];
|
||||
inputs[0] = in + lid_x * N_READS;
|
||||
for (int i = 1; i < N_WRITES; i++) {
|
||||
inputs[i] = inputs[i - 1] + reduction_size;
|
||||
}
|
||||
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, inputs, blocks, extra, lsize_x, lid_x);
|
||||
}
|
||||
|
||||
/**
 * Consecutive rows in an arbitrarily ordered array.
 */
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void per_thread_row_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
const device T* in,
|
||||
const size_t row_idx,
|
||||
int blocks,
|
||||
int extra,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
const constant int& ndim,
|
||||
uint lsize_x,
|
||||
uint lid_x,
|
||||
uint2 tid) {
|
||||
Op op;
|
||||
|
||||
// Each threadgroup handles 1 reduction
|
||||
// TODO: Specializing elem_to_loc would be slightly faster
|
||||
int idx = tid.y * out_size + tid.x;
|
||||
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
|
||||
in += extra_offset + lid_x * N_READS;
|
||||
|
||||
// The reduction is accumulated here
|
||||
U total_val = Op::init;
|
||||
|
||||
// Loop over the reduction size within thread group
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[i];
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total_val = op(static_cast<U>(vals[i]), total_val);
|
||||
}
|
||||
|
||||
in += lsize_x * N_READS;
|
||||
uint lid_x) {
|
||||
// Set up the input pointers
|
||||
const device T* inputs[N_WRITES];
|
||||
in += lid_x * N_READS;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
|
||||
if (reduction_index < reduction_size) {
|
||||
int max_reads = reduction_size - reduction_index;
|
||||
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
int idx = min(i, max_reads - 1);
|
||||
vals[i] = static_cast<U>(in[idx]);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
T val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
return total_val;
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, inputs, blocks, extra, lsize_x, lid_x);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general(
|
||||
const device T* in [[buffer(0)]],
|
||||
device mlx_atomic<U>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
/**
 * Reduce within the threadgroup.
 */
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
METAL_FUNC void threadgroup_reduce(
|
||||
thread U totals[N_WRITES],
|
||||
threadgroup U* shared_vals,
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
(void)non_row_reductions;
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
reduction_size,
|
||||
out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim,
|
||||
lsize.x,
|
||||
lid.x,
|
||||
tid.xy);
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
// Simdgroup first
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
totals[i] = op.simd_reduce(totals[i]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if multiple simd groups
|
||||
if (reduction_size > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
total_val = op.simd_reduce(total_val);
|
||||
}
|
||||
// Update output
|
||||
if (lid.x == 0) {
|
||||
op.atomic_update(out, total_val, tid.x);
|
||||
// Across simdgroups
|
||||
if (simd_per_group > 1) {
|
||||
if (simd_lane_id == 0) {
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
shared_vals[simd_group_id * N_WRITES + i] = totals[i];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
U values[N_WRITES];
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
|
||||
: op.init;
|
||||
}
|
||||
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
totals[i] = op.simd_reduce(values[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general_no_atomics(
|
||||
METAL_FUNC void
|
||||
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
|
||||
Op op;
|
||||
for (int i = 0; i < blocks; i++) {
|
||||
U vals[N_READS];
|
||||
for (int j = 0; j < N_READS; j++) {
|
||||
vals[j] = row[j];
|
||||
}
|
||||
for (int j = 0; j < N_READS; j++) {
|
||||
total = op(vals[j], total);
|
||||
}
|
||||
row += N_READS;
|
||||
}
|
||||
for (int i = 0; i < extra; i++) {
|
||||
total = op(*row++, total);
|
||||
}
|
||||
}
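// Illustrative host-side sketch (not kernel source) of the blocks/extra split
// that thread_reduce uses: a row of length row_size is consumed as `blocks`
// chunks of N_READS values plus `extra` trailing values. A sum reduction is
// assumed for the example.
#include <vector>

template <typename T, typename U, int N_READS = 4>
U thread_reduce_reference(const std::vector<T>& row) {
  int blocks = static_cast<int>(row.size()) / N_READS;
  int extra = static_cast<int>(row.size()) % N_READS;
  U total = U(0); // stands in for Op::init
  const T* ptr = row.data();
  for (int b = 0; b < blocks; b++) {
    for (int j = 0; j < N_READS; j++) {
      total += static_cast<U>(ptr[j]);
    }
    ptr += N_READS;
  }
  for (int i = 0; i < extra; i++) {
    total += static_cast<U>(*ptr++);
  }
  return total;
}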
|
||||
|
||||
// Reduction kernels
// - `row_reduce_small` depending on the non-row reductions and row size it
//   either just loops over everything or a simd collaboratively reduces the
//   non_row reductions. In the first case one thread is responsible for one
//   output; in the second, one simd is responsible for one output.
// - `row_reduce_simple` simple contiguous row reduction
// - `row_reduce_looped` simply loop and reduce each row for each non-row
//   reduction. One threadgroup is responsible for one output.
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_small(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& row_size [[buffer(2)]],
|
||||
const constant size_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 tsize [[threads_per_grid]]) {
|
||||
Op op;
|
||||
|
||||
U total_val = Op::init;
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
|
||||
// Precompute some row reduction numbers
|
||||
const device T* row;
|
||||
int blocks = row_size / N_READS;
|
||||
int extra = row_size % N_READS;
|
||||
|
||||
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
|
||||
// Simple loop over non_row_reductions and reduce the row in the thread.
|
||||
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
for (uint r = 0; r < non_row_reductions; r++) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
out[out_idx] = total_val;
|
||||
} else {
|
||||
// Collaboratively reduce over non_row_reductions in the simdgroup. Each
|
||||
// thread reduces every 32nd row and then a simple simd reduce.
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim);
|
||||
|
||||
loop.next(simd_lane_id, reduce_shape, reduce_strides);
|
||||
|
||||
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
|
||||
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
|
||||
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
|
||||
loop.next(simd_size, reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
if (simd_lane_id == 0) {
|
||||
out[out_idx] = total_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int N_READS = REDUCE_N_READS,
|
||||
int N_WRITES = REDUCE_N_WRITES>
|
||||
[[kernel]] void row_reduce_simple(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
(void)non_row_reductions;
|
||||
threadgroup U shared_vals[simd_size * N_WRITES];
|
||||
U totals[N_WRITES];
|
||||
|
||||
Op op;
|
||||
|
||||
threadgroup U local_vals[simd_size];
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
reduction_size,
|
||||
out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim,
|
||||
lsize.x,
|
||||
lid.x,
|
||||
tid.xy);
|
||||
|
||||
// Reduction within simd group - simd_add isn't supported for int64 types
|
||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
// Move to the row
|
||||
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z));
|
||||
if (out_idx + N_WRITES > out_size) {
|
||||
out_idx = out_size - N_WRITES;
|
||||
}
|
||||
in += out_idx * reduction_size;
|
||||
out += out_idx;
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Each thread reduces across the row
|
||||
int blocks = reduction_size / (lsize.x * N_READS);
|
||||
int extra = reduction_size - blocks * (lsize.x * N_READS);
|
||||
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if thread group has multiple simd groups
|
||||
if (ceildiv(reduction_size, N_READS) > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
// Reduce across the threadgroup
|
||||
threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
|
||||
totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
|
||||
|
||||
// Write the output
|
||||
if (lid.x == 0) {
|
||||
for (int i = 0; i < N_WRITES; i++) {
|
||||
out[i] = totals[i];
|
||||
}
|
||||
}
|
||||
// Write row reduce output for threadgroup with 1st thread in thread group
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIMS = 0,
|
||||
int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_looped(
|
||||
const device T* in [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant size_t& row_size [[buffer(2)]],
|
||||
const constant size_t& non_row_reductions [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
const constant int* reduce_shape [[buffer(7)]],
|
||||
const constant size_t* reduce_strides [[buffer(8)]],
|
||||
const constant int& reduce_ndim [[buffer(9)]],
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
uint3 gsize [[threadgroups_per_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
Op op;
|
||||
threadgroup U shared_vals[simd_size];
|
||||
U total = Op::init;
|
||||
|
||||
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
|
||||
|
||||
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
|
||||
// needs a small refactor.
|
||||
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS;
|
||||
|
||||
looped_elem_to_loc<NDIMS> loop;
|
||||
const device T* row;
|
||||
int blocks = row_size / (lsize.x * N_READS);
|
||||
int extra = row_size - blocks * (lsize.x * N_READS);
|
||||
|
||||
for (size_t i = 0; i < non_row_reductions; i++) {
|
||||
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim);
|
||||
|
||||
// Each thread reduces across the row
|
||||
U row_total;
|
||||
per_thread_row_reduce<T, U, Op, N_READS, 1>(
|
||||
&row_total, &row, blocks, extra, lsize.x, lid.x);
|
||||
|
||||
// Aggregate across rows
|
||||
total = op(total, row_total);
|
||||
|
||||
loop.next(reduce_shape, reduce_strides);
|
||||
}
|
||||
|
||||
// Reduce across the threadgroup
|
||||
threadgroup_reduce<T, U, Op, N_READS, 1>(
|
||||
&total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);
|
||||
|
||||
// Write the output
|
||||
if (lid.x == 0) {
|
||||
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
|
||||
out[out_idx] = total;
|
||||
}
|
||||
}
|
||||
|
@@ -4,44 +4,36 @@
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope_single(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
void rope_single_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
// Figure out L and d.
|
||||
uint2 pos,
|
||||
uint2 grid) {
|
||||
float L = scale * static_cast<float>(offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float theta = L * inv_freq;
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Compute the input and output indices
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
uint index_1, index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x + pos.y * stride;
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x + pos.y * stride;
|
||||
in_index_2 = in_index_1 + 1;
|
||||
index_1 = 2 * pos.x + pos.y * stride;
|
||||
index_2 = index_1 + 1;
|
||||
} else {
|
||||
out_index_1 = pos.x + pos.y * stride;
|
||||
out_index_2 = out_index_1 + grid.x;
|
||||
in_index_1 = pos.x + pos.y * stride;
|
||||
in_index_2 = in_index_1 + grid.x;
|
||||
index_1 = pos.x + pos.y * stride;
|
||||
index_2 = index_1 + grid.x;
|
||||
}
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float x1 = static_cast<float>(in[index_1]);
|
||||
float x2 = static_cast<float>(in[index_2]);
|
||||
float rx1;
|
||||
float rx2;
|
||||
if (forward) {
|
||||
@@ -51,28 +43,58 @@ template <typename T, bool traditional, bool forward>
|
||||
rx1 = x2 * sintheta + x1 * costheta;
|
||||
rx2 = x2 * costheta - x1 * sintheta;
|
||||
}
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
out[index_1] = static_cast<T>(rx1);
|
||||
out[index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope_single(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
float inv_freq = metal::exp2(-d * base);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward>
|
||||
[[kernel]] void rope_single_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t& stride,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
uint2 pos [[thread_position_in_grid]],
|
||||
uint2 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
rope_single_impl<T, traditional, forward>(
|
||||
in, out, offset, inv_freq, scale, stride, pos, grid);
|
||||
}
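// Illustrative sketch (plain C++, not kernel source) of the rotation the
// implementations above apply to one (x1, x2) pair: theta is the scaled
// position times the per-dimension inverse frequency, and the vjp direction
// applies the inverse rotation. The forward formula here is the standard RoPE
// rotation and is stated as an assumption, since only the backward branch is
// visible in this hunk.
#include <cmath>
#include <utility>

inline std::pair<float, float> rope_rotate_pair(
    float x1,
    float x2,
    float offset,
    float inv_freq,
    float scale,
    bool forward) {
  float theta = scale * offset * inv_freq;
  float c = std::cos(theta);
  float s = std::sin(theta);
  if (forward) {
    return {x1 * c - x2 * s, x1 * s + x2 * c};
  }
  return {x2 * s + x1 * c, x2 * c - x1 * s};
}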
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
void rope_impl(
|
||||
const device T* in,
|
||||
device T* out,
|
||||
constant const int& offset,
|
||||
const float inv_freq,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Figure out L and d.
|
||||
uint3 pos,
|
||||
uint3 grid) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float theta = L * inv_freq;
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
@@ -116,37 +138,115 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
constant const float& base [[buffer(10)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
float inv_freq = metal::exp2(-d * base);
|
||||
rope_impl<T, traditional, forward, N>(
|
||||
in,
|
||||
out,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
template <typename T, bool traditional, bool forward, int N = 4>
|
||||
[[kernel]] void rope_freqs(
|
||||
const device T* in [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
constant const int& offset,
|
||||
constant const float& scale,
|
||||
constant const size_t strides[3],
|
||||
constant const size_t out_strides[3],
|
||||
constant const size_t& n_batch,
|
||||
const device float* freqs [[buffer(10)]],
|
||||
constant const size_t& freq_stride [[buffer(11)]],
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
|
||||
rope_impl<T, traditional, forward, N>(
|
||||
in,
|
||||
out,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
pos,
|
||||
grid);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_rope_g(name, type, traditional, forward) \
|
||||
template [[host_name("rope_" #name)]] [[kernel]] void \
|
||||
rope<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_freqs_" #name)]] \
|
||||
[[kernel]] void rope_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t strides[3], \
|
||||
constant const size_t out_strides[3], \
|
||||
constant const size_t& n_batch, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
#define instantiate_rope_s(name, type, traditional, forward) \
|
||||
template [[host_name("rope_single_" #name)]] [[kernel]] void \
|
||||
rope_single<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
constant const float& base [[buffer(10)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]); \
|
||||
template [[host_name("rope_single_freqs_" #name)]] \
|
||||
[[kernel]] void rope_single_freqs<type, traditional, forward>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const int& offset, \
|
||||
constant const float& scale, \
|
||||
constant const size_t& stride, \
|
||||
const device float* freqs [[buffer(10)]], \
|
||||
constant const size_t& freq_stride [[buffer(11)]], \
|
||||
uint2 pos [[thread_position_in_grid]], \
|
||||
uint2 grid [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_rope(name, type, traditional, forward) \
|
||||
instantiate_rope_s(name, type, traditional, forward) \
|
||||
instantiate_rope_g(name, type, traditional, forward)
|
||||
instantiate_rope_g(name, type, traditional, forward)
|
||||
|
||||
// clang-format off
|
||||
instantiate_rope(traditional_float16, half, true, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
|
||||
instantiate_rope(traditional_float32, float, true, true)
|
||||
|
@@ -18,7 +18,7 @@ METAL_FUNC void scatter_1d_index_impl(
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
Op op;
|
||||
|
||||
uint out_idx = 0;
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; i++) {
|
||||
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
|
||||
out_idx += idx_val * out_strides[i];
|
||||
|
@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
|
||||
const constant ImplicitGemmConv2DParams* gemm_params_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -218,7 +218,7 @@ struct Conv2DWeightBlockLoaderGeneral {
|
||||
const short base_ww_,
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]])
|
||||
: src_ld(params_ -> wt_strides[0]),
|
||||
: src_ld(params_->wt_strides[0]),
|
||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
||||
bi(thread_idx / TCOLS),
|
||||
bj(vec_size * (thread_idx % TCOLS)),
|
||||
|
@@ -64,6 +64,7 @@ instantiate_unary_all(Cos, complex64, complex64_t)
|
||||
instantiate_unary_all(Cosh, complex64, complex64_t)
|
||||
instantiate_unary_all(Exp, complex64, complex64_t)
|
||||
instantiate_unary_all(Negative, complex64, complex64_t)
|
||||
instantiate_unary_all(Sign, complex64, complex64_t)
|
||||
instantiate_unary_all(Sin, complex64, complex64_t)
|
||||
instantiate_unary_all(Sinh, complex64, complex64_t)
|
||||
instantiate_unary_all(Tan, complex64, complex64_t)
|
||||
|
@@ -308,6 +308,14 @@ struct Sign {
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
if (x == complex64_t(0)) {
|
||||
return x;
|
||||
}
|
||||
return x /
|
||||
(complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag);
|
||||
};
|
||||
};
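// Worked example for the complex specialization above: sign(3 + 4i) is
// (3 + 4i) / 5 = 0.6 + 0.8i since |3 + 4i| = 5, and sign(0) stays 0 via the
// early return.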
|
||||
|
||||
struct Sin {
|
||||
|
@@ -64,6 +64,16 @@ struct Limits<bool> {
|
||||
static constexpr constant bool min = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<complex64_t> {
|
||||
static constexpr constant complex64_t max = complex64_t(
|
||||
metal::numeric_limits<float>::infinity(),
|
||||
metal::numeric_limits<float>::infinity());
|
||||
static constexpr constant complex64_t min = complex64_t(
|
||||
-metal::numeric_limits<float>::infinity(),
|
||||
-metal::numeric_limits<float>::infinity());
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -101,6 +111,34 @@ METAL_FUNC stride_t elem_to_loc(
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
stride_t elem,
|
||||
device const int* shape,
|
||||
device const stride_t* strides,
|
||||
int ndim) {
|
||||
stride_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
stride_t elem,
|
||||
constant const int* shape,
|
||||
constant const stride_t* strides,
|
||||
int ndim) {
|
||||
stride_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
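// Illustrative host-side version of elem_to_loc (a sketch, not the header
// above): the flat element index is peeled from the innermost dimension
// outward, accumulating index * stride per axis. For shape {2, 3}, element 4
// maps to indices (1, 1): with contiguous strides {3, 1} that is offset 4,
// while with transposed strides {1, 2} it is offset 1 * 1 + 1 * 2 = 3.
#include <cstddef>

inline size_t elem_to_loc_reference(
    size_t elem,
    const int* shape,
    const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % static_cast<size_t>(shape[i])) * strides[i];
    elem /= static_cast<size_t>(shape[i]);
  }
  return loc;
}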
|
||||
|
||||
// Non templated version to handle arbitrary dims
|
||||
template <typename stride_t>
|
||||
METAL_FUNC stride_t elem_to_loc(
|
||||
@@ -288,12 +326,87 @@ METAL_FUNC uint3 elem_to_loc_3_nd(
|
||||
return loc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int dim, typename offset_t = size_t>
|
||||
struct looped_elem_to_loc {
|
||||
looped_elem_to_loc<dim - 1, offset_t> inner_looper;
|
||||
offset_t offset{0};
|
||||
int index{0};
|
||||
|
||||
void next(const constant int* shape, const constant size_t* strides) {
|
||||
index++;
|
||||
offset += strides[dim - 1];
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
void next(int n, const constant int* shape, const constant size_t* strides) {
|
||||
index += n;
|
||||
offset += n * strides[dim - 1];
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
offset_t
|
||||
location(offset_t, const constant int*, const constant size_t*, int) {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename offset_t>
|
||||
struct looped_elem_to_loc<1, offset_t> {
|
||||
offset_t offset{0};
|
||||
|
||||
void next(const constant int*, const constant size_t* strides) {
|
||||
offset += strides[0];
|
||||
}
|
||||
|
||||
void next(int n, const constant int*, const constant size_t* strides) {
|
||||
offset += n * strides[0];
|
||||
}
|
||||
|
||||
offset_t
|
||||
location(offset_t, const constant int*, const constant size_t*, int) {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename offset_t>
|
||||
struct looped_elem_to_loc<0, offset_t> {
|
||||
void next(const constant int*, const constant size_t*) {}
|
||||
void next(int, const constant int*, const constant size_t*) {}
|
||||
|
||||
offset_t location(
|
||||
offset_t idx,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
int ndim) {
|
||||
return elem_to_loc(idx, shape, strides, ndim);
|
||||
}
|
||||
};
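// Illustrative host-side sketch (not kernel source) of what the struct above
// avoids: instead of recomputing elem_to_loc on every step, a running offset
// is updated like a row-major odometer, adding one stride per step and
// carrying into the next dimension on wrap-around.
#include <cstddef>
#include <vector>

struct LoopedOffsetReference {
  std::vector<int> index;
  size_t offset = 0;

  explicit LoopedOffsetReference(int ndim) : index(ndim, 0) {}

  // Advance by one element in row-major order over `shape`.
  void next(const std::vector<int>& shape, const std::vector<size_t>& strides) {
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
      index[i]++;
      offset += strides[i];
      if (index[i] < shape[i]) {
        return;
      }
      // Wrap this axis: remove its full contribution and carry outward.
      offset -= static_cast<size_t>(shape[i]) * strides[i];
      index[i] = 0;
    }
  }
};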
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////

/** Compute ceil((float)N/(float)M) */
|
||||
inline size_t ceildiv(size_t N, size_t M) {
|
||||
template <typename T, typename U>
|
||||
inline T ceildiv(T N, U M) {
|
||||
return (N + M - 1) / M;
|
||||
}
|
||||
|
||||
@@ -339,3 +452,8 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
||||
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
||||
}
|
||||
|
||||
inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
|
||||
return complex64_t(
|
||||
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
|
||||
}
|
||||
|
@@ -14,9 +14,9 @@ SRC_NAME=$(basename -- "${SRC_FILE}")
|
||||
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
|
||||
mkdir -p $OUTPUT_DIR
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
CONTENT=$($CC -I $SRC_DIR -DMLX_METAL_JIT -E -P $INPUT_FILE $CFLAGS 2>/dev/null)
|
||||
CONTENT=$($CC -I "$SRC_DIR" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
namespace mlx::core::metal {
|
||||
|
@@ -47,8 +47,6 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
// TODO, consider committing the buffer and encoding a wait in the new
|
||||
// buffer rather than on the task thread
|
||||
input.event().wait();
|
||||
}
|
||||
}
|
||||
|
@@ -135,8 +135,8 @@ void RMSNormVJP::eval_gpu(
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
// Allocate the gradient accumulator gw and a temporary to store the
|
||||
// gradients before they are accumulated.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
@@ -146,11 +146,7 @@ void RMSNormVJP::eval_gpu(
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
|
||||
const int simd_size = 32;
|
||||
const int n_reads = RMS_N_READS;
|
||||
@@ -330,8 +326,8 @@ void LayerNormVJP::eval_gpu(
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
// Allocate a temporary to store the gradients for w and initialize the
|
||||
// gradient accumulator to 0.
|
||||
// Allocate a temporary to store the gradients for w and allocate the output
|
||||
// gradient accumulators.
|
||||
array gw_temp({n_rows, x.shape().back()}, gw.dtype(), nullptr, {});
|
||||
bool g_in_gw = false;
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
@@ -341,12 +337,8 @@ void LayerNormVJP::eval_gpu(
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
}
|
||||
copies.push_back(gw_temp);
|
||||
{
|
||||
array zero(0, gw.dtype());
|
||||
copy_gpu(zero, gw, CopyType::Scalar, s);
|
||||
copy_gpu(zero, gb, CopyType::Scalar, s);
|
||||
copies.push_back(std::move(zero));
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
||||
|
||||
// Finish with the gradient for b in case we had a b
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -4,12 +4,14 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -197,7 +199,27 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
static Stream io_stream = new_stream(Device::cpu);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto read_task = [out = out,
|
||||
offset = offset_,
|
||||
reader = reader_,
|
||||
swap_endianness = swap_endianness_]() mutable {
|
||||
load(out, offset, reader, swap_endianness);
|
||||
};
|
||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||
auto signal_task = [out = out, fut = std::move(fut)]() {
|
||||
fut.wait();
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(io_stream, std::move(signal_task));
|
||||
auto& d = metal::device(stream().device);
|
||||
d.end_encoding(stream().index);
|
||||
auto command_buffer = d.get_command_buffer(stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(out.event().raw_event().get()),
|
||||
out.event().value());
|
||||
}
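// The control flow above in miniature (standard-library sketch; none of the
// names below are mlx APIs): the read runs on an IO thread, a follow-up task
// waits on its future and raises a signal, and the consumer blocks on that
// signal before using the buffer, mirroring how the command buffer encodes a
// wait on out.event().
#include <algorithm>
#include <future>
#include <thread>
#include <vector>

int main() {
  std::promise<void> ready; // plays the role of out.event()
  auto ready_fut = ready.get_future();
  std::vector<char> buffer(1024);

  // IO thread: fill the buffer (stands in for load(out, offset, reader, ...)).
  auto read_fut = std::async(std::launch::async, [&buffer] {
    std::fill(buffer.begin(), buffer.end(), 0x42);
  });

  // Signal task: wait for the read to finish, then signal readiness.
  std::thread signaler([&ready, f = std::move(read_fut)]() mutable {
    f.wait();
    ready.set_value();
  });

  ready_fut.wait(); // consumer side: the encodeWait equivalent
  signaler.join();
  return buffer[0] == 0x42 ? 0 : 1;
}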
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
(File diff suppressed because it is too large)
@@ -16,7 +16,8 @@ void all_reduce_dispatch(
|
||||
const std::string& op_name,
|
||||
CommandEncoder& compute_encoder,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
const Stream& s,
|
||||
std::vector<array>& copies);
|
||||
|
||||
void row_reduce_general_dispatch(
|
||||
const array& in,
|
||||
|
@@ -67,8 +67,10 @@ void RoPE::eval_gpu(
|
||||
// Special case for inference (single time step and contiguous)
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
|
||||
bool with_freqs = inputs.size() == 2;
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
|
||||
kname << "rope_" << (single ? "single_" : "")
|
||||
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
|
||||
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -78,27 +80,36 @@ void RoPE::eval_gpu(
|
||||
compute_encoder.set_input_array(donated ? out : in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&base, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 4);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 3);
|
||||
|
||||
size_t n_batch = in.size() / mat_size;
|
||||
MTL::Size group_dims;
|
||||
MTL::Size grid_dims;
|
||||
if (single) {
|
||||
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(out_strides, sizeof(size_t), 4);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
auto group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
auto grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
group_dims = get_block_dims(dim0, n_batch, 1);
|
||||
grid_dims = MTL::Size(dim0, n_batch, 1);
|
||||
} else {
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&n_batch, sizeof(size_t), 6);
|
||||
uint32_t dim0 = dims_ / 2;
|
||||
uint32_t dim1 = in.shape(-2);
|
||||
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
}
|
||||
|
||||
if (with_freqs) {
|
||||
auto& freqs = inputs[1];
|
||||
compute_encoder.set_input_array(freqs, 10);
|
||||
auto freq_stride = freqs.strides()[0];
|
||||
compute_encoder->setBytes(&freq_stride, sizeof(size_t), 11);
|
||||
} else {
|
||||
compute_encoder->setBytes(&base, sizeof(float), 10);
|
||||
}
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
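// For reference, the kernel-name construction above yields strings such as
// "rope_single_freqs_float16" (single step, user-provided frequencies,
// forward, non-traditional) or "rope_vjp_traditional_float32" (batched,
// backward, traditional). The concrete examples are illustrative; the names
// must match the host_name instantiations in the rope kernel file.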
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -56,9 +56,12 @@ void ternary_op_gpu_inplace(
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
bool donate_c = c.data_shared_ptr() == nullptr;
|
||||
compute_encoder.set_input_array(donate_a ? out : a, 0);
|
||||
compute_encoder.set_input_array(donate_b ? out : b, 1);
|
||||
compute_encoder.set_input_array(donate_c ? out : c, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
if (topt == TernaryOpType::General) {
|
||||
@@ -91,9 +94,10 @@ void ternary_op_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
// Launch a 1D or 2D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
@@ -119,6 +120,14 @@ NO_GPU_MULTI(RMSNormVJP)
|
||||
NO_GPU_MULTI(RoPE)
|
||||
NO_GPU(ScaledDotProductAttention)
|
||||
NO_GPU_MULTI(AffineQuantize)
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
} // namespace distributed
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -50,17 +50,4 @@ struct Group {
|
||||
*/
|
||||
Group init(bool strict = false);
|
||||
|
||||
namespace detail {
|
||||
|
||||
/* Return the communication stream. */
|
||||
Stream communication_stream();
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_sum(Group group, const array& input, array& output);
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_gather(Group group, const array& input, array& output);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
mlx/distributed/distributed_impl.h (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
|
||||
namespace mlx::core::distributed::detail {
|
||||
|
||||
/* Return the communication stream. */
|
||||
Stream communication_stream();
|
||||
|
||||
/* Perform an all reduce sum operation */
|
||||
void all_sum(Group group, const array& input, array& output);
|
||||
|
||||
/* Perform an all gather operation */
|
||||
void all_gather(Group group, const array& input, array& output);
|
||||
|
||||
/** Send an array to the dst rank */
|
||||
void send(Group group, const array& input, int dst);
|
||||
|
||||
/** Recv an array from the src rank */
|
||||
void recv(Group group, array& out, int src);
|
||||
|
||||
} // namespace mlx::core::distributed::detail
|
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#define LOAD_SYMBOL(symbol, variable) \
|
||||
@@ -47,6 +48,8 @@ struct MPIWrapper {
|
||||
LOAD_SYMBOL(MPI_Comm_free, comm_free);
|
||||
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
|
||||
LOAD_SYMBOL(MPI_Allgather, all_gather);
|
||||
LOAD_SYMBOL(MPI_Send, send);
|
||||
LOAD_SYMBOL(MPI_Recv, recv);
|
||||
|
||||
// Objects
|
||||
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
|
||||
@@ -141,6 +144,8 @@ struct MPIWrapper {
|
||||
MPI_Comm);
|
||||
int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
|
||||
int (*comm_free)(MPI_Comm*);
|
||||
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
|
||||
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
|
||||
|
||||
// Objects
|
||||
MPI_Comm comm_world_;
|
||||
@@ -284,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
|
||||
to_comm(group));
|
||||
}
|
||||
|
||||
void send(Group group, const array& input_, int dst) {
|
||||
array input = ensure_row_contiguous(input_);
|
||||
mpi().send(
|
||||
input.data<void>(),
|
||||
input.size(),
|
||||
mpi().datatype(input),
|
||||
dst,
|
||||
0,
|
||||
to_comm(group));
|
||||
}
|
||||
|
||||
void recv(Group group, array& out, int src) {
|
||||
MPI_Status status;
|
||||
mpi().recv(
|
||||
out.data<void>(),
|
||||
out.size(),
|
||||
mpi().datatype(out),
|
||||
src,
|
||||
MPI_ANY_TAG,
|
||||
to_comm(group),
|
||||
&status);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
@@ -33,6 +34,8 @@ Stream communication_stream() {
|
||||
|
||||
void all_sum(Group group, const array& input, array& output) {}
|
||||
void all_gather(Group group, const array& input, array& output) {}
|
||||
void send(Group group, const array& input, int dst) {}
|
||||
void recv(Group group, array& out, int src) {}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
|
||||
@@ -17,7 +19,10 @@ Group to_group(std::optional<Group> group) {
|
||||
|
||||
} // namespace
|
||||
|
||||
array all_sum(const array& x, std::optional<Group> group_) {
|
||||
array all_sum(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
@@ -27,11 +32,14 @@ array all_sum(const array& x, std::optional<Group> group_) {
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_shared<AllReduce>(group, AllReduce::Sum),
|
||||
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
|
||||
{x});
|
||||
}
|
||||
|
||||
array all_gather(const array& x, std::optional<Group> group_) {
|
||||
array all_gather(
|
||||
const array& x,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
@@ -47,8 +55,63 @@ array all_gather(const array& x, std::optional<Group> group_) {
|
||||
return array(
|
||||
std::move(result_shape),
|
||||
x.dtype(),
|
||||
std::make_shared<AllGather>(group),
|
||||
std::make_shared<AllGather>(to_stream(s), group),
|
||||
{x});
|
||||
}
|
||||
|
||||
array send(
|
||||
const array& x,
|
||||
int dst,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot send to a singleton group");
|
||||
}
|
||||
|
||||
if (dst < 0 || dst >= group.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid destination=" << dst << " for a group of size "
|
||||
<< group.size();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
return array(
|
||||
{0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
|
||||
}
|
||||
|
||||
array recv(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
int src,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto group = to_group(group_);
|
||||
|
||||
if (group.size() == 1) {
|
||||
throw std::invalid_argument("Cannot recv from a singleton group");
|
||||
}
|
||||
|
||||
if (src < 0 || src >= group.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid source=" << src << " for a group of size " << group.size();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
return array(
|
||||
std::move(shape),
|
||||
std::move(dtype),
|
||||
std::make_shared<Recv>(to_stream(s), group, src),
|
||||
std::vector<array>{});
|
||||
}
|
||||
|
||||
array recv_like(
|
||||
const array& x,
|
||||
int src,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return recv(x.shape(), x.dtype(), src, group_, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -5,10 +5,37 @@
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
array all_sum(const array& x, std::optional<Group> group = std::nullopt);
|
||||
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
|
||||
array all_sum(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array all_gather(
|
||||
const array& x,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice S = {});
|
||||
|
||||
array send(
|
||||
const array& x,
|
||||
int dst,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array recv(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
int src,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array recv_like(
|
||||
const array& x,
|
||||
int src,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -3,7 +3,6 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -36,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
const std::vector<int>& axes) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {{all_sum(inputs[0], group())}, axes};
|
||||
return {{all_sum(inputs[0], group(), stream())}, axes};
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
@@ -48,7 +47,7 @@ std::vector<array> AllReduce::jvp(
|
||||
const std::vector<int>& argnums) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {all_sum(tangents[0], group())};
|
||||
return {all_sum(tangents[0], group(), stream())};
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
@@ -76,14 +75,14 @@ void AllGather::eval_cpu(
|
||||
std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
return {{all_gather(inputs[0], group())}, axes};
|
||||
return {{all_gather(inputs[0], group(), stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> AllGather::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
return {all_gather(tangents[0], group())};
|
||||
return {all_gather(tangents[0], group(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> AllGather::vjp(
|
||||
@@ -99,4 +98,29 @@ std::vector<array> AllGather::vjp(
|
||||
return {slice(cotangents[0], starts, stops)};
|
||||
}
|
||||
|
||||
void Send::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
distributed::detail::send(group(), inputs[0], dst_);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
return {{send(inputs[0], dst_, group(), stream())}, axes};
|
||||
}
|
||||
|
||||
void Recv::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -3,20 +3,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
class DistPrimitive : public Primitive {
|
||||
public:
|
||||
DistPrimitive(Group group)
|
||||
: Primitive(detail::communication_stream()), group_(group) {}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error(
|
||||
"Communication primitives cannot be run on the GPU");
|
||||
}
|
||||
DistPrimitive(Stream stream, Group group)
|
||||
: Primitive(stream), group_(group) {}
|
||||
|
||||
const Group& group() const {
|
||||
return group_;
|
||||
@@ -30,11 +25,13 @@ class AllReduce : public DistPrimitive {
|
||||
public:
|
||||
enum ReduceType { And, Or, Sum, Prod, Min, Max };
|
||||
|
||||
AllReduce(Group group, ReduceType reduce_type)
|
||||
: DistPrimitive(group), reduce_type_(reduce_type) {}
|
||||
AllReduce(Stream stream, Group group, ReduceType reduce_type)
|
||||
: DistPrimitive(stream, group), reduce_type_(reduce_type) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
@@ -77,10 +74,13 @@ class AllReduce : public DistPrimitive {
|
||||
|
||||
class AllGather : public DistPrimitive {
|
||||
public:
|
||||
AllGather(Group group) : DistPrimitive(group) {}
|
||||
AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
|
||||
DEFINE_PRINT(AllGather);
|
||||
};
|
||||
|
||||
class Send : public DistPrimitive {
|
||||
public:
|
||||
Send(Stream stream, Group group, int dst)
|
||||
: DistPrimitive(stream, group), dst_(dst) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
DEFINE_PRINT(Send);
|
||||
|
||||
private:
|
||||
int dst_;
|
||||
};
|
||||
|
||||
class Recv : public DistPrimitive {
|
||||
public:
|
||||
Recv(Stream stream, Group group, int src)
|
||||
: DistPrimitive(stream, group), src_(src) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(Recv);
|
||||
|
||||
private:
|
||||
int src_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
370  mlx/fast.cpp
@@ -1,8 +1,10 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <regex>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -323,7 +325,7 @@ bool LayerNormVJP::is_equivalent(const Primitive& other) const {
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
std::vector<array> inputs,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
@@ -331,15 +333,23 @@ array rope(
|
||||
int offset,
|
||||
bool forward,
|
||||
StreamOrDevice s) {
|
||||
auto& x = inputs[0];
|
||||
if (x.ndim() < 3) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have at least 3 dimensions but got input with "
|
||||
<< x.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (inputs.size() == 2 &&
|
||||
(inputs[1].ndim() != 1 || inputs[1].shape(0) != dims / 2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] freqs must be one dimensional with size " << dims / 2
|
||||
<< " but got shape " << inputs[1].shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, offset, forward, s](
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<array> inputs) {
|
||||
auto& shape = inputs[0].shape();
|
||||
int ndim = shape.size();
|
||||
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
|
||||
@@ -348,10 +358,20 @@ array rope(
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
||||
auto freqs = negative(arange(0, half_dims, t, s), s);
|
||||
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
|
||||
|
||||
auto default_inv_freqs = [&inputs, &s, &t, base, half_dims]() {
|
||||
return exp(
|
||||
multiply(
|
||||
arange(0, -half_dims, -1, t, s),
|
||||
array(std::log(base) / half_dims, t),
|
||||
s),
|
||||
s);
|
||||
};
|
||||
|
||||
auto inv_freqs =
|
||||
inputs.size() == 2 ? reciprocal(inputs[1], s) : default_inv_freqs();
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(inv_freqs, 0, s), s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
@@ -409,20 +429,39 @@ array rope(
|
||||
x.dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
stream, fallback, dims, traditional, base, scale, offset, forward),
|
||||
{x});
|
||||
std::move(inputs));
|
||||
}
|
||||
return fallback({x})[0];
|
||||
return fallback(std::move(inputs))[0];
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
std::optional<float> base,
|
||||
float scale,
|
||||
int offset,
|
||||
const std::optional<array>& freqs /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return rope(x, dims, traditional, base, scale, offset, true, s);
|
||||
std::vector<array> inputs = {x};
|
||||
if (freqs) {
|
||||
inputs.push_back(astype(*freqs, float32, s));
|
||||
if (base) {
|
||||
throw std::invalid_argument(
|
||||
"[rope] Only one of base or freqs can have a value.");
|
||||
}
|
||||
} else if (!base) {
|
||||
throw std::invalid_argument("[rope] Neither base nor freqs has a value.");
|
||||
}
|
||||
return rope(
|
||||
std::move(inputs),
|
||||
dims,
|
||||
traditional,
|
||||
base.has_value() ? *base : 1.0,
|
||||
scale,
|
||||
offset,
|
||||
true,
|
||||
s);
|
||||
}
|
||||
|
||||
std::vector<array> RoPE::vjp(
|
||||
@@ -438,16 +477,27 @@ std::vector<array> RoPE::vjp(
|
||||
offset = offset_,
|
||||
forward = forward_,
|
||||
s](std::vector<array> inputs) {
|
||||
return std::vector<array>{
|
||||
rope(inputs[0], dims, traditional, base, scale, offset, !forward, s)};
|
||||
return std::vector<array>{rope(
|
||||
std::move(inputs),
|
||||
dims,
|
||||
traditional,
|
||||
base,
|
||||
scale,
|
||||
offset,
|
||||
!forward,
|
||||
s)};
|
||||
};
|
||||
|
||||
auto inputs = cotangents;
|
||||
if (primals.size() == 2) {
|
||||
inputs.push_back(primals[1]);
|
||||
}
|
||||
return {array(
|
||||
cotangents[0].shape(),
|
||||
cotangents[0].dtype(),
|
||||
std::make_shared<RoPE>(
|
||||
s, fallback, dims_, traditional_, base_, scale_, offset_, !forward_),
|
||||
cotangents)};
|
||||
std::move(inputs))};
|
||||
}
|
||||
|
||||
bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
@@ -465,6 +515,7 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::optional<array>& mask,
|
||||
const std::optional<int>& memory_efficient_threshold,
|
||||
StreamOrDevice s) {
|
||||
for (const auto& tensor : {queries, keys, values}) {
|
||||
if (tensor.ndim() != 4) {
|
||||
@@ -535,6 +586,11 @@ array scaled_dot_product_attention(
|
||||
* * dtype is not fp32 or fp16
|
||||
*/
|
||||
|
||||
int threshold = 1e6;
|
||||
if (memory_efficient_threshold.has_value()) {
|
||||
threshold = std::max(1, memory_efficient_threshold.value());
|
||||
}
|
||||
|
||||
bool needs_mask = mask.has_value();
|
||||
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
|
||||
const std::vector<array>& inputs) {
|
||||
@@ -581,9 +637,10 @@ array scaled_dot_product_attention(
|
||||
bool implementation_supports_use_case =
|
||||
supports_sdpa || supports_full_self_attention;
|
||||
|
||||
// disabling full self attention until perf is tuned;
|
||||
// likewise for sdpa
|
||||
implementation_supports_use_case &= false;
|
||||
// sdpa gpu shader is disabled except for memory efficient opt-in
|
||||
const int seq_for_threshold = queries.shape(2);
|
||||
bool use_memory_efficient_impl = seq_for_threshold >= threshold;
|
||||
implementation_supports_use_case &= use_memory_efficient_impl;
|
||||
|
||||
if (implementation_supports_use_case) {
|
||||
auto out_shape =
|
||||
@@ -859,4 +916,285 @@ array affine_dequantize(
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
void validate_output_shapes(
|
||||
std::map<std::string, std::vector<int>> output_shapes,
|
||||
std::map<std::string, Dtype> output_dtypes) {
|
||||
// Make sure output shapes and dtypes have the same keys
|
||||
bool validated = true;
|
||||
if (output_shapes.size() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] Must specify at least one output.");
|
||||
}
|
||||
if (output_shapes.size() != output_dtypes.size()) {
|
||||
validated = false;
|
||||
} else {
|
||||
for (const auto& kv : output_shapes) {
|
||||
if (output_dtypes.find(kv.first) == output_dtypes.end()) {
|
||||
validated = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!validated) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] `output_shapes` and `output_dtypes` must have the same keys.");
|
||||
}
|
||||
}
|
||||
|
||||
void write_signature(
|
||||
std::string func_name,
|
||||
std::string& source,
|
||||
std::map<std::string, array>& inputs,
|
||||
std::map<std::string, std::vector<int>>& output_shapes,
|
||||
std::map<std::string, Dtype>& output_dtypes,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs,
|
||||
std::ostringstream& kernel_source) {
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
if (template_args && template_args.value().size() > 0) {
|
||||
kernel_source << "template <";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args.value()) {
|
||||
std::string param_type;
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
param_type = "int";
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
param_type = "bool";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
param_type = "typename";
|
||||
}
|
||||
if (i > 0) {
|
||||
kernel_source << ", ";
|
||||
}
|
||||
kernel_source << param_type << " " << name;
|
||||
i++;
|
||||
}
|
||||
kernel_source << ">" << std::endl;
|
||||
}
|
||||
kernel_source << "[[kernel]] void " << func_name << "(" << std::endl;
|
||||
|
||||
// Metal attributes are automatically added to the arguments if present
|
||||
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
|
||||
{"dispatch_quadgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_simdgroups_per_threadgroup", "uint"},
|
||||
{"dispatch_threads_per_threadgroup", "uint3"},
|
||||
{"grid_origin", "uint3"},
|
||||
{"grid_size", "uint3"},
|
||||
{"quadgroup_index_in_threadgroup", "uint"},
|
||||
{"quadgroups_per_threadgroup", "uint"},
|
||||
{"simdgroup_index_in_threadgroup", "uint"},
|
||||
{"simdgroups_per_threadgroup", "uint"},
|
||||
{"thread_execution_width", "uint"},
|
||||
{"thread_index_in_quadgroup", "uint"},
|
||||
{"thread_index_in_simdgroup", "uint"},
|
||||
{"thread_index_in_threadgroup", "uint"},
|
||||
{"thread_position_in_grid", "uint3"},
|
||||
{"thread_position_in_threadgroup", "uint3"},
|
||||
{"threadgroup_position_in_grid", "uint3"},
|
||||
{"threadgroups_per_grid", "uint3"},
|
||||
{"threads_per_grid", "uint3"},
|
||||
{"threads_per_simdgroup", "uint"},
|
||||
{"thread_per_threadgroup", "uint3"},
|
||||
};
|
||||
std::vector<std::pair<std::string, std::string>> attrs;
|
||||
for (const auto& [attr, dtype] : metal_attributes) {
|
||||
if (source.find(attr) != std::string::npos) {
|
||||
attrs.push_back({attr, dtype});
|
||||
}
|
||||
}
|
||||
|
||||
int index = 0;
|
||||
constexpr int max_constant_array_size = 8;
|
||||
// Add inputs
|
||||
for (const auto& [name, arr] : inputs) {
|
||||
auto dtype = get_type_string(arr.dtype());
|
||||
bool is_constant =
|
||||
arr.is_available() && arr.size() < max_constant_array_size;
|
||||
std::string location = is_constant ? "constant" : "device";
|
||||
std::string ref = arr.ndim() == 0 ? "&" : "*";
|
||||
kernel_source << " const " << location << " " << dtype << ref << " "
|
||||
<< name << " [[buffer(" << index << ")]]," << std::endl;
|
||||
index++;
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
CustomKernelShapeInfo shape_info;
|
||||
if (arr.ndim() > 0) {
|
||||
if (source.find(name + "_shape") != std::string::npos) {
|
||||
kernel_source << " const constant int* " << name << "_shape [[buffer("
|
||||
<< index << ")]]," << std::endl;
|
||||
shape_info.shape = true;
|
||||
index++;
|
||||
}
|
||||
if (source.find(name + "_strides") != std::string::npos) {
|
||||
kernel_source << " const constant size_t* " << name
|
||||
<< "_strides [[buffer(" << index << ")]]," << std::endl;
|
||||
shape_info.strides = true;
|
||||
index++;
|
||||
}
|
||||
if (source.find(name + "_ndim") != std::string::npos) {
|
||||
kernel_source << " const constant int& " << name << "_ndim [[buffer("
|
||||
<< index << ")]]," << std::endl;
|
||||
shape_info.ndim = true;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
shape_infos.push_back(shape_info);
|
||||
}
|
||||
// Add outputs
|
||||
for (const auto& [name, dtype] : output_dtypes) {
|
||||
kernel_source << " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source << "atomic<" << type_string << ">";
|
||||
} else {
|
||||
kernel_source << type_string;
|
||||
}
|
||||
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
|
||||
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
kernel_source << ") {" << std::endl;
|
||||
}
|
||||
index++;
|
||||
}
|
||||
// Add metal attributes e.g. `threadgroup_index_in_grid`
|
||||
index = 0;
|
||||
for (const auto& [attr, dtype] : attrs) {
|
||||
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
|
||||
if (index < attrs.size() - 1) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
kernel_source << ") {" << std::endl;
|
||||
}
|
||||
index++;
|
||||
}
|
||||
kernel_source << source << std::endl;
|
||||
kernel_source << "}" << std::endl;
|
||||
}
|
||||
|
||||
std::string write_template(std::map<std::string, TemplateArg>& template_args) {
|
||||
std::ostringstream template_def;
|
||||
template_def << "<";
|
||||
int i = 0;
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (i > 0) {
|
||||
template_def << ", ";
|
||||
}
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
template_def << std::get<int>(arg);
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
template_def << std::get<bool>(arg);
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
template_def << get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
i++;
|
||||
}
|
||||
template_def << ">";
|
||||
return template_def.str();
|
||||
}
|
||||
|
||||
std::map<std::string, array> MetalKernel::operator()(
|
||||
std::map<std::string, array>& inputs,
|
||||
std::map<std::string, std::vector<int>> output_shapes,
|
||||
std::map<std::string, Dtype> output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s_) {
|
||||
validate_output_shapes(output_shapes, output_dtypes);
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument(
|
||||
"[metal_kernel] MetalKernel only works on GPU.");
|
||||
}
|
||||
|
||||
std::ostringstream func_name;
|
||||
|
||||
std::string template_def = "";
|
||||
bool needs_template = template_args && template_args.value().size() > 0;
|
||||
std::string hash_key = "";
|
||||
if (needs_template) {
|
||||
std::regex disallowed_chars("\\<|\\>|(, )");
|
||||
template_def = write_template(template_args.value());
|
||||
hash_key = std::regex_replace(template_def, disallowed_chars, "_");
|
||||
hash_key.pop_back();
|
||||
}
|
||||
|
||||
func_name << "custom_kernel_" << name_ << hash_key;
|
||||
std::string kernel_name = func_name.str();
|
||||
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << header_ << std::endl;
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
write_signature(
|
||||
func_name.str(),
|
||||
source_,
|
||||
inputs,
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos,
|
||||
atomic_outputs_,
|
||||
kernel_source);
|
||||
|
||||
if (needs_template) {
|
||||
template_def = func_name.str() + template_def;
|
||||
kernel_source << std::endl
|
||||
<< "template [[host_name(\"" << kernel_name
|
||||
<< "\")]] [[kernel]] decltype(" << template_def << ") "
|
||||
<< template_def << ";" << std::endl;
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << name_ << "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source.str() << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
std::vector<array> in_arrs;
|
||||
for (const auto& kv : inputs) {
|
||||
in_arrs.push_back(kv.second);
|
||||
}
|
||||
|
||||
std::vector<std::string> out_keys;
|
||||
std::vector<std::vector<int>> out_shapes;
|
||||
for (const auto& [name, shape] : output_shapes) {
|
||||
out_keys.push_back(name);
|
||||
out_shapes.push_back(shape);
|
||||
}
|
||||
|
||||
std::vector<Dtype> out_dtypes;
|
||||
for (const auto& kv : output_dtypes) {
|
||||
out_dtypes.push_back(kv.second);
|
||||
}
|
||||
|
||||
std::map<std::string, array> outputs;
|
||||
auto outputs_vec = array::make_arrays(
|
||||
out_shapes,
|
||||
out_dtypes,
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
kernel_name,
|
||||
kernel_source.str(),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous_,
|
||||
init_value),
|
||||
in_arrs);
|
||||
|
||||
int i = 0;
|
||||
for (const auto& key : out_keys) {
|
||||
outputs.insert({key, outputs_vec[i]});
|
||||
i++;
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
40  mlx/fast.h
@@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
@@ -25,9 +26,10 @@ array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
std::optional<float> base,
|
||||
float scale,
|
||||
int offset,
|
||||
const std::optional<array>& freqs = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Computes: O = softmax(Q @ K.T) @ V **/
|
||||
@@ -37,6 +39,7 @@ array scaled_dot_product_attention(
|
||||
const array& values,
|
||||
const float scale,
|
||||
const std::optional<array>& mask = std::nullopt,
|
||||
const std::optional<int>& memory_efficient_threshold = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
std::tuple<array, array, array> affine_quantize(
|
||||
@@ -61,4 +64,39 @@ array affine_dequantize(
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
typedef std::variant<int, bool, Dtype> TemplateArg;
|
||||
|
||||
class MetalKernel {
|
||||
public:
|
||||
MetalKernel(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
const std::string& header = "",
|
||||
bool ensure_row_contiguous = true,
|
||||
bool atomic_outputs = false)
|
||||
: name_(name),
|
||||
source_(source),
|
||||
header_(header),
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
atomic_outputs_(atomic_outputs) {}
|
||||
|
||||
std::map<std::string, array> operator()(
|
||||
std::map<std::string, array>& inputs,
|
||||
std::map<std::string, std::vector<int>> output_shapes,
|
||||
std::map<std::string, Dtype> output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args =
|
||||
std::nullopt,
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::string source_;
|
||||
std::string header_;
|
||||
bool ensure_row_contiguous_;
|
||||
bool atomic_outputs_;
|
||||
};
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
@@ -242,4 +244,50 @@ class AffineQuantize : public Custom {
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
struct CustomKernelShapeInfo {
|
||||
bool shape = false;
|
||||
bool strides = false;
|
||||
bool ndim = false;
|
||||
};
|
||||
|
||||
class CustomKernel : public Primitive {
|
||||
public:
|
||||
CustomKernel(
|
||||
Stream stream,
|
||||
std::string name,
|
||||
std::string source,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::vector<CustomKernelShapeInfo> shape_infos,
|
||||
bool ensure_row_contiguous,
|
||||
std::optional<float> init_value)
|
||||
: Primitive(stream),
|
||||
source_(source),
|
||||
name_(name),
|
||||
grid_(grid),
|
||||
threadgroup_(threadgroup),
|
||||
shape_infos_(shape_infos),
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
init_value_(init_value) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("Custom Metal kernels only run on GPU.");
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(CustomKernel);
|
||||
|
||||
private:
|
||||
std::string source_;
|
||||
std::string name_;
|
||||
std::tuple<int, int, int> grid_;
|
||||
std::tuple<int, int, int> threadgroup_;
|
||||
std::vector<CustomKernelShapeInfo> shape_infos_;
|
||||
bool ensure_row_contiguous_;
|
||||
std::optional<float> init_value_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@@ -2,6 +2,7 @@

#include <cstdint>
#include <cstring>
#include <fstream>
#include <numeric>

#include "mlx/io/gguf.h"
@@ -298,7 +298,65 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
|
||||
|
||||
/** Load array from file in .npy format */
|
||||
array load(std::string file, StreamOrDevice s) {
|
||||
return load(std::make_shared<io::FileReader>(std::move(file)), s);
|
||||
return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);
|
||||
}
|
||||
|
||||
namespace io {
|
||||
|
||||
ThreadPool& thread_pool() {
|
||||
static ThreadPool pool_{4};
|
||||
return pool_;
|
||||
}
|
||||
|
||||
ThreadPool ParallelFileReader::thread_pool_{4};
|
||||
|
||||
void ParallelFileReader::read(char* data, size_t n) {
|
||||
while (n != 0) {
|
||||
auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
|
||||
if (m <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[read] Unable to read " << n << " bytes from file.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
data += m;
|
||||
n -= m;
|
||||
}
|
||||
}
|
||||
|
||||
void ParallelFileReader::read(char* data, size_t n, size_t offset) {
|
||||
auto readfn = [fd = fd_](size_t offset, size_t size, char* buffer) -> bool {
|
||||
while (size != 0) {
|
||||
auto m = pread(fd, buffer, size, offset);
|
||||
if (m <= 0) {
|
||||
return false;
|
||||
}
|
||||
buffer += m;
|
||||
size -= m;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
std::vector<std::future<bool>> futs;
|
||||
while (n != 0) {
|
||||
if (n < batch_size_) {
|
||||
if (!readfn(offset, n, data)) {
|
||||
throw std::runtime_error("[read] Unable to read from file.");
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
size_t m = batch_size_;
|
||||
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
|
||||
data += m;
|
||||
n -= m;
|
||||
offset += m;
|
||||
}
|
||||
}
|
||||
for (auto& f : futs) {
|
||||
if (!f.get()) {
|
||||
throw std::runtime_error("[read] Unable to read from file.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,14 +2,20 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/io/threadpool.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace io {
|
||||
|
||||
ThreadPool& thread_pool();
|
||||
|
||||
class Reader {
|
||||
public:
|
||||
virtual bool is_open() const = 0;
|
||||
@@ -19,7 +25,9 @@ class Reader {
|
||||
int64_t off,
|
||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||
virtual void read(char* data, size_t n) = 0;
|
||||
virtual void read(char* data, size_t n, size_t offset) = 0;
|
||||
virtual std::string label() const = 0;
|
||||
virtual ~Reader() = default;
|
||||
};
|
||||
|
||||
class Writer {
|
||||
@@ -32,73 +40,93 @@ class Writer {
|
||||
std::ios_base::seekdir way = std::ios_base::beg) = 0;
|
||||
virtual void write(const char* data, size_t n) = 0;
|
||||
virtual std::string label() const = 0;
|
||||
virtual ~Writer() = default;
|
||||
};
|
||||
|
||||
class FileReader : public Reader {
|
||||
class ParallelFileReader : public Reader {
|
||||
public:
|
||||
explicit FileReader(std::ifstream is)
|
||||
: is_(std::move(is)), label_("stream") {}
|
||||
explicit FileReader(std::string file_path)
|
||||
: is_(std::ifstream(file_path, std::ios::binary)),
|
||||
label_(std::move(file_path)) {}
|
||||
explicit ParallelFileReader(std::string file_path)
|
||||
: fd_(open(file_path.c_str(), O_RDONLY)), label_(std::move(file_path)) {}
|
||||
|
||||
~ParallelFileReader() override {
|
||||
close(fd_);
|
||||
}
|
||||
|
||||
bool is_open() const override {
|
||||
return is_.is_open();
|
||||
return fd_ > 0;
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return is_.good();
|
||||
return is_open();
|
||||
}
|
||||
|
||||
size_t tell() override {
|
||||
return is_.tellg();
|
||||
return lseek(fd_, 0, SEEK_CUR);
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
is_.seekg(off, way);
|
||||
void seek(int64_t, std::ios_base::seekdir = std::ios_base::beg) override {
|
||||
throw std::runtime_error("[ParallelFileReader::seek] Not allowed");
|
||||
}
|
||||
|
||||
void read(char* data, size_t n) override {
|
||||
is_.read(data, n);
|
||||
}
|
||||
// Warning: do not use this function from multiple threads as
|
||||
// it advances the file descriptor
|
||||
void read(char* data, size_t n) override;
|
||||
|
||||
void read(char* data, size_t n, size_t offset) override;
|
||||
|
||||
std::string label() const override {
|
||||
return "file " + label_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
static constexpr size_t batch_size_ = 1 << 25;
|
||||
static ThreadPool thread_pool_;
|
||||
int fd_;
|
||||
std::string label_;
|
||||
};
|
||||
|
||||
class FileWriter : public Writer {
|
||||
public:
|
||||
explicit FileWriter(std::ofstream os)
|
||||
: os_(std::move(os)), label_("stream") {}
|
||||
explicit FileWriter(std::string file_path)
|
||||
: os_(std::ofstream(file_path, std::ios::binary)),
|
||||
: fd_(open(file_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0644)),
|
||||
label_(std::move(file_path)) {}
|
||||
|
||||
~FileWriter() override {
|
||||
close(fd_);
|
||||
}
|
||||
|
||||
bool is_open() const override {
|
||||
return os_.is_open();
|
||||
return fd_ >= 0;
|
||||
}
|
||||
|
||||
bool good() const override {
|
||||
return os_.good();
|
||||
return is_open();
|
||||
}
|
||||
|
||||
size_t tell() override {
|
||||
return os_.tellp();
|
||||
return lseek(fd_, 0, SEEK_CUR);
|
||||
}
|
||||
|
||||
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
|
||||
override {
|
||||
os_.seekp(off, way);
|
||||
if (way == std::ios_base::beg) {
|
||||
lseek(fd_, off, 0);
|
||||
} else {
|
||||
lseek(fd_, off, SEEK_CUR);
|
||||
}
|
||||
}
|
||||
|
||||
void write(const char* data, size_t n) override {
|
||||
os_.write(data, n);
|
||||
while (n != 0) {
|
||||
auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
|
||||
if (m <= 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[write] Unable to write " << n << " bytes to file.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
data += m;
|
||||
n -= m;
|
||||
}
|
||||
}
|
||||
|
||||
std::string label() const override {
|
||||
@@ -106,7 +134,7 @@ class FileWriter : public Writer {
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
int fd_;
|
||||
std::string label_;
|
||||
};
|
||||
|
||||
|
@@ -147,7 +147,7 @@ SafetensorsLoad load_safetensors(
}

SafetensorsLoad load_safetensors(const std::string& file, StreamOrDevice s) {
  return load_safetensors(std::make_shared<io::FileReader>(file), s);
  return load_safetensors(std::make_shared<io::ParallelFileReader>(file), s);
}

void save_safetensors(
103  mlx/io/threadpool.h
@@ -0,0 +1,103 @@
// This code was modified from https://github.com/progschj/ThreadPool
// The original License is copied below:
//
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.
//
// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.
//
// 3. This notice may not be removed or altered from any source
// distribution.
#pragma once

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>

class ThreadPool {
 public:
  ThreadPool(size_t);
  template <class F, class... Args>
  auto enqueue(F&& f, Args&&... args)
      -> std::future<typename std::result_of_t<F(Args...)>>;
  ~ThreadPool();

 private:
  std::vector<std::thread> workers;
  std::queue<std::function<void()>> tasks;
  std::mutex queue_mutex;
  std::condition_variable condition;
  bool stop;
};

inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
  for (size_t i = 0; i < threads; ++i)
    workers.emplace_back([this] {
      for (;;) {
        std::function<void()> task;

        {
          std::unique_lock<std::mutex> lock(this->queue_mutex);
          this->condition.wait(
              lock, [this] { return this->stop || !this->tasks.empty(); });
          if (this->stop && this->tasks.empty())
            return;
          task = std::move(this->tasks.front());
          this->tasks.pop();
        }

        task();
      }
    });
}

template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of_t<F(Args...)>> {
  using return_type = typename std::result_of_t<F(Args...)>;

  auto task = std::make_shared<std::packaged_task<return_type()>>(
      std::bind(std::forward<F>(f), std::forward<Args>(args)...));

  std::future<return_type> res = task->get_future();
  {
    std::unique_lock<std::mutex> lock(queue_mutex);

    if (stop) {
      throw std::runtime_error(
          "[ThreadPool::enqueue] Not allowed on stopped ThreadPool");
    }

    tasks.emplace([task]() { (*task)(); });
  }
  condition.notify_one();
  return res;
}

inline ThreadPool::~ThreadPool() {
  {
    std::unique_lock<std::mutex> lock(queue_mutex);
    stop = true;
  }
  condition.notify_all();
  for (std::thread& worker : workers)
    worker.join();
}
@@ -306,6 +306,49 @@ array cholesky(
      {a});
}

array pinv(const array& a, StreamOrDevice s /* = {} */) {
  if (a.dtype() != float32) {
    std::ostringstream msg;
    msg << "[linalg::pinv] Arrays must type float32. Received array "
        << "with type " << a.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }
  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
        << "with " << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  int m = a.shape(-2);
  int n = a.shape(-1);
  int k = std::min(m, n);
  auto outs = linalg::svd(a, s);
  array U = outs[0];
  array S = outs[1];
  array V = outs[2];

  std::vector<int> starts(a.ndim(), 0);
  std::vector<int> ends = a.shape();
  int i = a.ndim() - 2;
  int j = a.ndim() - 1;

  // Prepare U
  ends[i] = m;
  ends[j] = k;
  U = swapaxes(slice(U, starts, ends, s), -1, -2, s);

  // Prepare V
  ends[i] = k;
  ends[j] = n;
  V = swapaxes(slice(V, starts, ends, s), -1, -2, s);

  // Prepare S
  S = expand_dims(S, -2, s);

  return matmul(divide(V, S, s), U);
}

array cholesky_inv(
    const array& L,
    bool upper /* = false */,
@@ -70,6 +70,8 @@ array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {});

array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});

array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

} // namespace mlx::core::linalg
166  mlx/ops.cpp
@@ -16,9 +16,11 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
|
||||
compute_reduce_shape(
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& shape) {
|
||||
bool is_noop = true;
|
||||
std::set<int> axes_set;
|
||||
auto ndim = shape.size();
|
||||
for (auto ax : axes) {
|
||||
@@ -35,15 +37,18 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
|
||||
throw std::invalid_argument("Duplicate axes detected in reduction.");
|
||||
}
|
||||
std::vector<int> out_shape;
|
||||
std::vector<int> squeezed_shape;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
if (axes_set.count(i) == 0) {
|
||||
out_shape.push_back(shape[i]);
|
||||
squeezed_shape.push_back(shape[i]);
|
||||
} else {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
is_noop &= (out_shape.back() == shape[i]);
|
||||
}
|
||||
std::vector<int> sorted_axes(axes_set.begin(), axes_set.end());
|
||||
return {out_shape, sorted_axes};
|
||||
return {out_shape, sorted_axes, squeezed_shape, is_noop};
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
@@ -1378,6 +1383,10 @@ array isinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return logical_or(isposinf(a, s), isneginf(a, s), s);
|
||||
}
|
||||
|
||||
array isfinite(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return logical_not(logical_or(isinf(a, s), isnan(a, s), s), s);
|
||||
}
|
||||
|
||||
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (issubdtype(a.dtype(), integer) || a.dtype() == bool_) {
|
||||
return full(a.shape(), false, bool_, s);
|
||||
@@ -1498,17 +1507,17 @@ array all(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.empty()) {
|
||||
return astype(a, bool_, s);
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? astype(a, bool_, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1532,17 +1541,17 @@ array any(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.empty()) {
|
||||
return astype(a, bool_, s);
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? astype(a, bool_, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
bool_,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1569,15 +1578,18 @@ array sum(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||
{a});
|
||||
auto out = (is_noop)
|
||||
? astype(a, out_type, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
out_type,
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1711,14 +1723,17 @@ array prod(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1745,17 +1760,17 @@ array max(
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument("[max] Cannot max reduce zero size array.");
|
||||
}
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1785,14 +1800,17 @@ array min(
|
||||
if (axes.empty()) {
|
||||
return a;
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape(axes, a.shape());
|
||||
auto out = (is_noop)
|
||||
? a
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
a.dtype(),
|
||||
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1825,15 +1843,18 @@ array argmin(
|
||||
throw std::invalid_argument(
|
||||
"[argmin] Cannot argmin reduce zero size array.");
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape({axis}, a.shape());
|
||||
auto out = (is_noop)
|
||||
? zeros(out_shape, uint32, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -1858,15 +1879,18 @@ array argmax(
|
||||
throw std::invalid_argument(
|
||||
"[argmax] Cannot argmax reduce zero size array.");
|
||||
}
|
||||
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
|
||||
auto out = array(
|
||||
out_shape,
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||
{a});
|
||||
auto [out_shape, sorted_axes, squeezed_shape, is_noop] =
|
||||
compute_reduce_shape({axis}, a.shape());
|
||||
auto out = (is_noop)
|
||||
? zeros(out_shape, uint32, s)
|
||||
: array(
|
||||
std::move(out_shape),
|
||||
uint32,
|
||||
std::make_shared<ArgReduce>(
|
||||
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
|
||||
{a});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, sorted_axes, s);
|
||||
out = reshape(out, std::move(squeezed_shape), s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -2903,6 +2927,10 @@ array softmax(
|
||||
const std::vector<int>& axes,
|
||||
bool precise /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (a.size() == 0) {
|
||||
return a;
|
||||
}
|
||||
|
||||
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
|
@@ -399,6 +399,8 @@ array isnan(const array& a, StreamOrDevice s = {});

array isinf(const array& a, StreamOrDevice s = {});

array isfinite(const array& a, StreamOrDevice s = {});

array isposinf(const array& a, StreamOrDevice s = {});

array isneginf(const array& a, StreamOrDevice s = {});
@@ -186,10 +186,35 @@ array eval_impl(std::vector<array> outputs, bool async) {
}

void async_eval(std::vector<array> outputs) {
  if (outputs.empty()) {
    return;
  }

  if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
        return x.status() == array::Status::unscheduled;
      })) {
    return;
  }

  eval_impl(std::move(outputs), true);
}

void eval(std::vector<array> outputs) {
  if (outputs.empty()) {
    return;
  }

  if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
        return x.status() == array::Status::unscheduled;
      })) {
    for (auto& x : outputs) {
      if (!x.is_available()) {
        x.event().wait();
      }
    }
    return;
  }

  eval_impl(std::move(outputs), false).event().wait();
}
@@ -4,10 +4,10 @@

#include <variant>

#include "array.h"
#include "device.h"
#include "dtype.h"
#include "stream.h"
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"

namespace mlx::core {
@@ -1,7 +1,7 @@
[build-system]
requires = [
    "setuptools>=42",
    "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
    "nanobind==2.1.0",
    "cmake>=3.24",
]
build-backend = "setuptools.build_meta"
@@ -234,7 +234,7 @@ def glorot_uniform(

def he_normal(
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
    r"""Build a He normal initializer.

    This initializer samples from a normal distribution with a standard

@@ -292,7 +292,7 @@ def he_normal(

def he_uniform(
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
) -> Callable[[mx.array, Literal["fan_in", "fan_out"], float], mx.array]:
    r"""A He uniform (Kaiming uniform) initializer.

    This initializer samples from a uniform distribution with a range
@@ -72,6 +72,8 @@ from mlx.nn.layers.recurrent import GRU, LSTM, RNN
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    Transformer,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
@@ -1,5 +1,7 @@

# Copyright © 2023 Apple Inc.

from __future__ import annotations

import textwrap
from typing import Any, Callable, List, Optional, Tuple, Union

@@ -7,42 +9,6 @@ import mlx.core as mx

from mlx.utils import tree_flatten, tree_unflatten

def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)

elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}

elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd

elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl

raise RuntimeError("Unexpected leaf found while traversing the module")

class Module(dict):
"""Base class for building neural networks with MLX.

@@ -151,7 +117,7 @@ class Module(dict):

self,
file_or_weights: Union[str, List[Tuple[str, mx.array]]],
strict: bool = True,
) -> "Module":
) -> Module:
"""
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.

@@ -266,9 +232,9 @@ class Module(dict):

def filter_and_map(
self,
filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
filter_fn: Callable[[Module, str, Any], bool],
map_fn: Optional[Callable] = None,
is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = None,
):
"""Recursively filter the contents of the module using ``filter_fn``,
namely only select keys and values where ``filter_fn`` returns true.

@@ -323,7 +289,7 @@ class Module(dict):

return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

def update(self, parameters: dict) -> "Module":
def update(self, parameters: dict) -> Module:
"""Replace the parameters of this Module with the provided ones in the
dict of dicts and lists.

@@ -371,8 +337,8 @@ class Module(dict):

def apply(
self,
map_fn: Callable[[mx.array], mx.array],
filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
) -> "Module":
filter_fn: Optional[Callable[[Module, str, Any], bool]] = None,
) -> Module:
"""Map all the parameters using the provided ``map_fn`` and immediately
update the module with the mapped parameters.

@@ -391,7 +357,7 @@ class Module(dict):

self.update(self.filter_and_map(filter_fn, map_fn))
return self

def update_modules(self, modules: dict) -> "Module":
def update_modules(self, modules: dict) -> Module:
"""Replace the child modules of this :class:`Module` instance with the
provided ones in the dict of dicts and lists.

@@ -432,9 +398,7 @@ class Module(dict):

apply(self, modules)
return self

def apply_to_modules(
self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]
) -> "Module":
def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
"""Apply a function to all the modules in this instance (including this
instance).

@@ -489,7 +453,7 @@ class Module(dict):

recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
) -> "Module":
) -> Module:
"""Freeze the Module's parameters or some of them. Freezing a parameter means not
computing gradients for it.

@@ -544,7 +508,7 @@ class Module(dict):

recurse: bool = True,
keys: Optional[Union[str, List[str]]] = None,
strict: bool = False,
) -> "Module":
) -> Module:
"""Unfreeze the Module's parameters or some of them.

This function is idempotent, i.e. unfreezing a model that is not frozen is

@@ -588,7 +552,7 @@ class Module(dict):

_unfreeze_impl("", self)
return self

def train(self, mode: bool = True) -> "Module":
def train(self, mode: bool = True) -> Module:
"""Set the model in or out of training mode.

Training mode only applies to certain layers. For example

@@ -608,7 +572,7 @@ class Module(dict):

self.apply_to_modules(_set_train)
return self

def eval(self) -> "Module":
def eval(self) -> Module:
"""Set the model to evaluation mode.

See :func:`train`.

@@ -637,3 +601,39 @@ class Module(dict):

return True

self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)

def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)

elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}

elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd

elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl

raise RuntimeError("Unexpected leaf found while traversing the module")
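The hunks above mostly tighten the `Module` type hints, but they touch the main user-facing methods (`freeze`/`unfreeze`, `apply`, `train`/`eval`). A minimal usage sketch, assuming only the signatures shown in the diff and the standard MLX `nn.Linear` layer with its `weight`/`bias` parameters:

```python
# Hedged sketch of the Module APIs whose annotations change above.
import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(4, 4)

# Freeze everything, then unfreeze only the bias (keys= accepts a str or list of str).
model.freeze()
model.unfreeze(keys="bias")

# apply() maps every parameter and updates the module in place, e.g. cast to float16.
model.apply(lambda p: p.astype(mx.float16))

# train()/eval() toggle training mode and return the module itself.
model.train()
model.eval()
```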
@@ -32,7 +32,7 @@ class Dropout(Module):

mask = mx.random.bernoulli(self._p_1, x.shape)

return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)

class Dropout2d(Module):

@@ -85,7 +85,7 @@ class Dropout2d(Module):

mask_shape[-2] = mask_shape[-3] = 1

mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)

class Dropout3d(Module):

@@ -134,4 +134,4 @@ class Dropout3d(Module):

mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1

mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)
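These hunks only reorder the multiplication, but they rely on the inverted-dropout scaling. A minimal sketch of that scaling with plain `mx` ops, assuming `_p_1` is the keep probability `1 - p` (which is how the layer names it):

```python
# Inverted dropout: mask with Bernoulli(keep) and rescale by 1/keep so the
# expected value of the output matches the input.
import mlx.core as mx

def dropout(x, p=0.5):
    keep = 1 - p
    mask = mx.random.bernoulli(keep, x.shape)
    return (mask * x) * (1 / keep)

x = mx.ones((4, 4))
y = dropout(x, p=0.5)
```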
@@ -190,9 +190,9 @@ class MaxPool1d(_Pool1d):

def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
super().__init__(mx.max, -float("inf"), kernel_size, stride, padding)

@@ -229,9 +229,9 @@ class AvgPool1d(_Pool1d):

def __init__(
self,
kernel_size: Union[int, Tuple[int, int]],
stride: Optional[Union[int, Tuple[int, int]]] = None,
padding: Optional[Union[int, Tuple[int, int]]] = 0,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
):
super().__init__(mx.mean, 0, kernel_size, stride, padding)
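The annotation fix above clarifies that the 1-D pools take a single-element size. A hedged usage sketch, assuming the usual MLX channels-last layout `(N, L, C)` for 1-D pooling input:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(2, 16, 8))            # (batch, length, channels), assumed layout
max_pool = nn.MaxPool1d(kernel_size=2, stride=2)
avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)

print(max_pool(x).shape)   # expected (2, 8, 8)
print(avg_pool(x).shape)   # expected (2, 8, 8)
```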
@@ -12,7 +12,7 @@ def quantize(

model: Module,
group_size: int = 64,
bits: int = 4,
class_predicate: Optional[callable] = None,
class_predicate: Optional[Callable] = None,
):
"""Quantize the sub-modules of a module according to a predicate.
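A minimal sketch of calling `nn.quantize` with the signature shown above; the model here is an illustrative assumption (layer widths are kept divisible by `group_size`), and `class_predicate` is left at its default so the built-in predicate decides which sub-modules are converted:

```python
import mlx.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# Quantize eligible sub-modules in place to 4-bit weights with groups of 64.
nn.quantize(model, group_size=64, bits=4)
```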
@@ -147,9 +147,9 @@ class TransformerEncoderLayer(Module):

else:
y = self.attention(x, x, x, mask)
y = self.dropout1(y)
y = self.ln1(x + y)
x = self.ln1(x + y)

y = self.linear1(y)
y = self.linear1(x)
y = self.activation(y)
y = self.dropout2(y)
y = self.linear2(y)
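The change above makes the normalized sum the new residual stream in the post-norm branch, so that the skip connection around the MLP also carries the attention output. A small sketch of the corrected data flow implied by the diff (not the full MLX layer):

```python
# Post-norm encoder block: ln1's output becomes the residual stream feeding
# both the MLP and the second skip connection.
def encoder_block(x, attention, ln1, ln2, mlp, mask=None):
    y = attention(x, x, x, mask)
    x = ln1(x + y)       # fixed: keep the normalized sum as the residual stream
    y = mlp(x)
    return ln2(x + y)    # second skip now includes the attention output
```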
@@ -1,7 +1,7 @@

# Copyright © 2023 Apple Inc.

import math
from typing import Literal
from typing import Literal, Optional

import mlx.core as mx

@@ -22,7 +22,7 @@ def _reduce(loss: mx.array, reduction: Reduction = "none"):

def cross_entropy(
logits: mx.array,
targets: mx.array,
weights: mx.array = None,
weights: Optional[mx.array] = None,
axis: int = -1,
label_smoothing: float = 0.0,
reduction: Reduction = "none",

@@ -117,7 +117,7 @@ def cross_entropy(

def binary_cross_entropy(
inputs: mx.array,
targets: mx.array,
weights: mx.array = None,
weights: Optional[mx.array] = None,
with_logits: bool = True,
reduction: Reduction = "mean",
) -> mx.array:
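A hedged usage sketch for the two loss signatures touched above (`weights` is now properly `Optional`); the toy logits and targets are assumptions for illustration:

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])
ce = nn.losses.cross_entropy(logits, targets, reduction="mean")

scores = mx.array([0.2, -0.3, 1.5])            # raw logits, since with_logits=True
labels = mx.array([0.0, 0.0, 1.0])
bce = nn.losses.binary_cross_entropy(scores, labels, with_logits=True, reduction="mean")
```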
@@ -1,7 +1,7 @@

# Copyright © 2023-2024 Apple Inc.

from functools import wraps
from typing import Callable
from typing import Callable, Optional

import mlx.core as mx

@@ -37,7 +37,7 @@ def value_and_grad(model: Module, fn: Callable):

return wrapped_value_grad_fn

def checkpoint(module: Module, fn: Callable = None):
def checkpoint(module: Module, fn: Optional[Callable] = None):
"""Transform the passed callable to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and
the callable's inputs).
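A hedged sketch of how `checkpoint` and `value_and_grad` from this file compose, assuming both are importable from `mlx.nn.utils` (the module this diff appears to edit) and that leaving `fn=None` checkpoints the module's own `__call__`:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import checkpoint, value_and_grad

layer = nn.Linear(8, 8)
ckpt = checkpoint(layer)          # recompute activations in the backward pass

def loss_fn(x):
    return ckpt(x).sum()

x = mx.random.normal(shape=(2, 8))
loss, grads = value_and_grad(layer, loss_fn)(x)
```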
@@ -4,6 +4,7 @@ import math

from typing import Callable, List, Optional, Tuple, Union

import mlx.core as mx
from mlx.nn import Module
from mlx.utils import tree_map, tree_reduce

@@ -17,7 +18,7 @@ class Optimizer:

self._state = {"step": mx.array(0, mx.uint64)}
self._schedulers = {k: v for k, v in (schedulers or {}).items()}

def update(self, model: "mlx.nn.Module", gradients: dict):
def update(self, model: Module, gradients: dict):
"""Apply the gradients to the parameters of the model and update the
model with the new parameters.

@@ -48,8 +49,28 @@ class Optimizer:

>>> optimizer.state.keys()
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
"""
self._state.update(tree_map(lambda x: {}, parameters))
tree_map(self.init_single, parameters, self._state)

# Initialize the optimizer state to match the parameter state
def update_state(params, state):
if isinstance(params, (list, tuple)):
state = list(state)
for i in range(len(state)):
state[i] = update_state(params[i], state[i])
if len(state) != len(params):
state.extend(tree_map(lambda x: {}, params[len(state) :]))
return type(params)(state)
elif isinstance(params, dict):
for k, v in params.items():
if k not in state:
state[k] = tree_map(lambda x: {}, v)
else:
state[k] = update_state(v, state[k])
return state
else:
return state

update_state(parameters, self._state)
tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
self._initialized = True

def init_single(self, parameter: mx.array, state: dict):

@@ -104,7 +125,7 @@ class Optimizer:

@state.setter
def state(self, state: dict):
self._initialized = True
self._initialized = False
self._state = state

@property
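The `update_state` helper above grows the optimizer state so it always mirrors the parameter tree, even when parameters are added after the first `init`. A short usage sketch matching the docstring quoted in the hunk:

```python
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=1e-1)

# Explicit initialization; otherwise state is created lazily on the first update().
optimizer.init(model.trainable_parameters())
print(optimizer.state.keys())   # e.g. dict_keys(['step', 'learning_rate', 'weight', 'bias'])
```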
@@ -1,10 +1,10 @@

# Copyright © 2023 Apple Inc.
from collections import defaultdict
from typing import Any, Callable, Tuple
from typing import Any, Callable, List, Optional, Tuple

def tree_map(
fn: Callable, tree: Any, *rest: Tuple[Any], is_leaf: Callable = None
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
) -> Any:
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
returns a new collection with the results.

@@ -59,8 +59,8 @@ def tree_map(

def tree_map_with_path(
fn: Callable,
tree: Any,
*rest: Tuple[Any],
is_leaf: Callable = None,
*rest: Any,
is_leaf: Optional[Callable] = None,
path: Any = None,
) -> Any:
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and

@@ -111,7 +111,9 @@ def tree_map_with_path(

return fn(path, tree, *rest)

def tree_flatten(tree, prefix="", is_leaf=None):
def tree_flatten(
tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None
) -> Any:
"""Flattens a Python tree to a list of key, value tuples.

The keys are using the dot notation to define trees of arbitrary depth and

@@ -155,7 +157,7 @@ def tree_flatten(tree, prefix="", is_leaf=None):

return [(prefix[1:], tree)]

def tree_unflatten(tree):
def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any:
"""Recreate a Python tree from its flat representation.

.. code-block:: python
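A minimal round-trip sketch for the `tree_flatten`/`tree_unflatten` signatures typed above: flattened keys use dot notation, and unflatten rebuilds the nested structure.

```python
from mlx.utils import tree_flatten, tree_unflatten

tree = {"layers": [{"w": 1, "b": 2}, {"w": 3, "b": 4}]}
flat = tree_flatten(tree)
# [('layers.0.w', 1), ('layers.0.b', 2), ('layers.1.w', 3), ('layers.1.b', 4)]
print(tree_unflatten(flat))   # nested dict/list structure restored
```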
@@ -9,6 +9,7 @@

#include <nanobind/stl/string.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <nanobind/typing.h>

#include "mlx/backend/metal/metal.h"
#include "python/src/buffer.h"

@@ -113,6 +114,7 @@ void init_array(nb::module_& m) {

.def("__hash__", [](const Dtype& t) {
return static_cast<int64_t>(t.val);
});

m.attr("bool_") = nb::cast(bool_);
m.attr("uint8") = nb::cast(uint8);
m.attr("uint16") = nb::cast(uint16);

@@ -177,7 +179,7 @@ void init_array(nb::module_& m) {

.export_values();
nb::class_<ArrayAt>(
m,
"_ArrayAt",
"ArrayAt",
R"pbdoc(
A helper object to apply updates at specific indices.
)pbdoc")

@@ -195,7 +197,7 @@ void init_array(nb::module_& m) {

nb::class_<ArrayPythonIterator>(
m,
"_ArrayIterator",
"ArrayIterator",
R"pbdoc(
A helper object to iterate over the 1st dimension of an array.
)pbdoc")

@@ -840,6 +842,8 @@ void init_array(nb::module_& m) {

},
"other"_a,
nb::rv_policy::none)
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
.def(
"flatten",
[](const array& a,
@@ -160,6 +160,10 @@ nb::ndarray<> mlx_to_dlpack(const array& a) {

}

nb::object to_scalar(array& a) {
if (a.size() != 1) {
throw std::invalid_argument(
"[convert] Only length-1 arrays can be converted to Python scalars.");
}
{
nb::gil_scoped_release nogil;
a.eval();
@@ -3,6 +3,8 @@

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>

#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"

@@ -74,8 +76,9 @@ void init_distributed(nb::module_& parent_module) {

"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_sum(x: array, *, group: Optional[Group] = None) -> array"),
"def all_sum(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
All reduce sum.

@@ -86,6 +89,8 @@ void init_distributed(nb::module_& parent_module) {

group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: The sum of all ``x`` arrays.

@@ -97,8 +102,9 @@ void init_distributed(nb::module_& parent_module) {

"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
"def all_gather(x: array, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Gather arrays from all processes.

@@ -110,8 +116,96 @@ void init_distributed(nb::module_& parent_module) {

group (Group): The group of processes that will participate in the
gather. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: The concatenation of all ``x`` arrays.
)pbdoc");

m.def(
"send",
&distributed::send,
"x"_a,
"dst"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Send an array from the current process to the process that has rank
``dst`` in the group.

Args:
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
send. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: An empty array which, when evaluated, performs the send.
)pbdoc");

m.def(
"recv",
&distributed::recv,
"shape"_a,
"dtype"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Recv an array with shape ``shape`` and dtype ``dtype`` from process
with rank ``src``.

Args:
shape (Tuple[int]): The shape of the array we are receiving.
dtype (Dtype): The data type of the array we are receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: The array that was received from ``src``.
)pbdoc");

m.def(
"recv_like",
&distributed::recv_like,
"x"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Recv an array with shape and type like ``x`` from process with rank
``src``.

It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.

Args:
x (array): An array defining the shape and dtype of the array we are
receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: The array that was received from ``src``.
)pbdoc");
}
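A hedged sketch of the distributed API documented above. It assumes the process group has already been launched with multiple ranks (for example via MPI), so on a single process it only exercises the trivial path:

```python
import mlx.core as mx

world = mx.distributed.init()
x = mx.ones((4,))

# Sum x across every process in the global group.
total = mx.distributed.all_sum(x)

# Point-to-point: rank 0 sends, rank 1 receives an array with the same shape/dtype.
if world.rank() == 0 and world.size() > 1:
    mx.eval(mx.distributed.send(x, dst=1))
elif world.rank() == 1:
    y = mx.distributed.recv_like(x, src=0)
    mx.eval(y)
```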
@@ -1,9 +1,14 @@

// Copyright © 2023-2024 Apple Inc.

#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>

#include "python/src/utils.h"

#include "mlx/fast.h"
#include "mlx/ops.h"

@@ -79,25 +84,29 @@ void init_fast(nb::module_& parent_module) {

"dims"_a,
nb::kw_only(),
"traditional"_a,
"base"_a,
"base"_a.none(),
"scale"_a,
"offset"_a,
"freqs"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def rope(a: array, dims: int, *, traditional: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array"),
"def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: int, freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Apply rotary positional encoding to the input.

Args:
a (array): Input array.
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
is larger than dims then the rest is left unchanged.
traditional (bool): If set to ``True`` choose the traditional
implementation which rotates consecutive dimensions.
base (float): The base used to compute angular frequency for
each dimension in the positional encodings.
implementation which rotates consecutive dimensions.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Exactly one of ``base`` and
``freqs`` must be ``None``.
scale (float): The scale used to scale the positions.
offset (int): The position offset to start at.
freqs (array, optional): Optional frequencies to use with RoPE.
If set, the ``base`` parameter must be ``None``. Default: ``None``.

Returns:
array: The output array.

@@ -112,9 +121,10 @@ void init_fast(nb::module_& parent_module) {

nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
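A hedged usage sketch matching the updated signatures above: `rope` with `base` set (so `freqs` stays `None`), and `scaled_dot_product_attention` with an optional mask left out. The input shapes are assumptions for illustration (batch, sequence, feature for RoPE; batch, heads, sequence, head_dim for attention):

```python
import math
import mlx.core as mx

x = mx.random.normal(shape=(1, 8, 64))
y = mx.fast.rope(x, 64, traditional=False, base=10000.0, scale=1.0, offset=0)

q = mx.random.normal(shape=(1, 4, 8, 16))
k = mx.random.normal(shape=(1, 4, 8, 16))
v = mx.random.normal(shape=(1, 4, 8, 16))
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(16))
```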
@@ -182,4 +192,153 @@ void init_fast(nb::module_& parent_module) {

Returns:
array: The quantized version of ``w``
)pbdoc");

nb::class_<fast::MetalKernel>(
m,
"metal_kernel",
R"pbdoc(
A jit-compiled custom Metal kernel defined from a source string.
)pbdoc")
.def(
nb::init<
const std::string&,
const std::string&,
const std::string&,
bool,
bool>(),
"name"_a,
"source"_a,
"header"_a = "",
"ensure_row_contiguous"_a = true,
"atomic_outputs"_a = false,
R"pbdoc(
Initialize a metal_kernel.

Args:
name (str): Name for the kernel.
source (str): Source code. This is the body of a function in Metal,
the function signature will be generated for you. The names of the inputs/outputs
are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
used when the kernel is called.
header (str): Header source code to include before the main function.
Useful for helper functions or includes that should live outside of the main function body.
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
before the kernel runs. Default: ``True``.
atomic_outputs (bool): Whether to use atomic outputs in the function signature
e.g. ``device atomic<float>``. Default: ``False``.
Returns:
Callable ``metal_kernel``.

Example:

.. code-block:: python

def exp_elementwise(a: mx.array):
source = '''
uint elem = thread_position_in_grid.x;
T tmp = inp[elem];
out[elem] = metal::exp(tmp);
'''

kernel = mx.fast.metal_kernel(
name="myexp",
source=source
)
outputs = kernel(
inputs={"inp": a},
template={"T": mx.float32},
grid=(a.size, 1, 1),
threadgroup=(256, 1, 1),
output_shapes={"out": a.shape},
output_dtypes={"out": a.dtype},
verbose=True,
)
return outputs["out"]

a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))
)pbdoc")
.def(
"__call__",
[](fast::MetalKernel& kernel,
std::map<std::string, ScalarOrArray>& inputs_,
std::map<std::string, std::vector<int>>& output_shapes,
std::map<std::string, Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,
std::optional<std::map<std::string, nb::handle>> template_args_,
std::optional<float> init_value,
bool verbose,
StreamOrDevice s) {
std::map<std::string, array> inputs;
for (const auto& [name, value] : inputs_) {
auto arr = to_array(value, std::nullopt);
inputs.insert({name, arr});
}
std::map<std::string, fast::TemplateArg> template_args;
if (template_args_) {
for (const auto& [name, value] : template_args_.value()) {
// Handle bool, int and dtype template args
if (nb::isinstance<bool>(value)) {
bool bool_val = nb::cast<bool>(value);
template_args.insert({name, bool_val});
} else if (nb::isinstance<int>(value)) {
int int_val = nb::cast<int>(value);
template_args.insert({name, int_val});
} else if (nb::isinstance<Dtype>(value)) {
Dtype dtype = nb::cast<Dtype>(value);
template_args.insert({name, dtype});
} else {
throw std::invalid_argument(
"[[metal_kernel]] Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`.");
}
}
}
return kernel(
inputs,
output_shapes,
output_dtypes,
grid,
threadgroup,
template_args,
init_value,
verbose,
s);
},
nb::kw_only(),
"inputs"_a,
"output_shapes"_a,
"output_dtypes"_a,
"grid"_a,
"threadgroup"_a,
"template"_a = nb::none(),
"init_value"_a = nb::none(),
"verbose"_a = false,
"stream"_a = nb::none(),
nb::sig(
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = False, stream: Union[None, Stream, Device] = None)"),
R"pbdoc(
Run the kernel.

Args:
inputs (Mapping[str, array]): Inputs. These will be added to the function signature and passed to the Metal kernel.
The keys will be the names of the arguments to the kernel.
output_shapes (Mapping[str, Sequence[int]]): Output shapes. A dict mapping
output variable names to shapes. These will be added to the function signature.
output_dtypes (Mapping[str, Dtype]): Output dtypes. A dict mapping output variable
names to dtypes. Must have the same keys as ``output_shapes``.
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
These will be added as template arguments to the kernel definition. Default: ``None``.
init_value (float, optional): Optional value to use to initialize all of the output arrays.
By default, output arrays are uninitialized. Default: ``None``.
verbose (bool, optional): Whether to print the full generated source code of the kernel
when it is run. Default: ``False``.
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.

Returns:
dict[str, array]: Dictionary of output arrays based on ``output_shapes``/``output_dtypes``.
)pbdoc");
}
@@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {

nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
"def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix or vector norm.

@@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {

a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the

@@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {

nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
R"pbdoc(
The QR factorization of the input matrix.

@@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {

nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix.

@@ -325,9 +325,9 @@ void init_linalg(nb::module_& parent_module) {

nb::sig(
"def cholesky_inv(L: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition L.
Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.

Let A be a real symmetric positive semi-definite matrix and L its Cholesky definition such that:
Let :math:`\mathbf{A}` be a real symmetric positive semi-definite matrix and :math:`\mathbf{L}` its Cholesky decomposition such that:

.. math::

@@ -339,7 +339,7 @@ void init_linalg(nb::module_& parent_module) {

This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the Cholesky inverse is computed for each matrix
in the last two dimensions of ``L``.
in the last two dimensions of :math:`\mathbf{L}`.

If the input matrix is not a triangular matrix behaviour is undefined.

@@ -351,6 +351,30 @@ void init_linalg(nb::module_& parent_module) {

in which case the default stream of the default device is used.

Returns:
array: :math:`A^{-1}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
array: :math:`\mathbf{A^{-1}}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
)pbdoc");
m.def(
"pinv",
&pinv,
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pinv(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the (Moore-Penrose) pseudo-inverse of a matrix.

This function calculates a generalized inverse of a matrix using its
singular-value decomposition. This function supports arrays with at least 2 dimensions.
When the input has more than two dimensions, the inverse is computed for each
matrix in the last two dimensions of ``a``.

Args:
a (array): Input array.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: ``aplus`` such that ``a @ aplus @ a = a``
)pbdoc");
}
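A hedged numerical check of the pseudo-inverse property documented above (`a @ aplus @ a = a`). The explicit `stream=mx.cpu` is an assumption based on MLX's linear-algebra ops typically running on the CPU stream:

```python
import mlx.core as mx

a = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
aplus = mx.linalg.pinv(a, stream=mx.cpu)
print(mx.allclose(a @ aplus @ a, a))   # expected: True
```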
@@ -138,6 +138,21 @@ class PyFileReader : public io::Reader {

void read(char* data, size_t n) override {
nb::gil_scoped_acquire gil;
_read(data, n);
}

void read(char* data, size_t n, size_t offset) override {
nb::gil_scoped_acquire gil;
seek_func_(offset, (int)std::ios_base::beg);
_read(data, n);
}

std::string label() const override {
return "python file object";
}

private:
void _read(char* data, size_t n) {
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE);
nb::object bytes_read = readinto_func_(nb::handle(memview));

@@ -146,11 +161,6 @@ class PyFileReader : public io::Reader {

}
}

std::string label() const override {
return "python file object";
}

private:
nb::object pyistream_;
nb::object readinto_func_;
nb::object seek_func_;
Some files were not shown because too many files have changed in this diff.