Compare commits

..

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
5b46b9bc52 Add priority-queue eval 2024-01-13 19:58:20 -08:00
Awni Hannun
fd94be28ea fix test + choose stream with slight care 2024-01-13 13:34:27 -08:00
Awni Hannun
9051fa1eaa Use a dummy primitive to only sync with one output 2024-01-13 13:08:19 -08:00
152 changed files with 2881 additions and 10171 deletions

View File

@@ -26,28 +26,18 @@ jobs:
command: |
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libblas-dev
- run:
name: Install Python package
name: Build python package
command: |
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
- run:
name: Generate package stubs
name: Run the python tests
command: |
python3 setup.py generate_stubs
- run:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
python3 -m unittest discover python/tests
- run:
name: Build CPP only
command: |
@@ -70,38 +60,24 @@ jobs:
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install torch
pip install tensorflow
pip install unittest-xml-reporting
- run:
name: Install Python package
name: Build python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
python setup.py generate_stubs
- run:
name: Run Python tests
name: Run the python tests
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# eval "$(conda shell.bash hook)"
# conda activate runner-env
# cd examples/extensions && python -m pip install .
- store_test_results:
path: test-results
- run:
@@ -133,27 +109,10 @@ jobs:
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install twine
# TODO: Update build system to switch away from setup.py develop
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
PYPI_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py develop
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
python setup.py generate_stubs
- run:
name: Publish Python package
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
@@ -186,26 +145,10 @@ jobs:
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install twine
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
DEV_RELEASE=1 \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py develop
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
python setup.py generate_stubs
- run:
name: Publish Python package
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
@@ -238,25 +181,10 @@ jobs:
conda activate runner-env
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install numpy
pip install twine
- run:
name: Install Python package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python setup.py develop
- run:
name: Generate package stubs
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env
python setup.py generate_stubs
- run:
name: Build package distribution
name: Build package
command: |
eval "$(conda shell.bash hook)"
conda activate runner-env

View File

@@ -9,7 +9,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 5.12.0
hooks:
- id: isort
args:

View File

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.1.0)
set(MLX_VERSION 0.0.9)
endif()
# --------------------- Processor tests -------------------------
@@ -31,13 +31,13 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, check the build"
" documentation for possible fixes: "
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
message(WARNING
"Building for x86_64 on macOS is not supported."
message(WARNING
"Building for x86_64 on macOS is not supported."
" If you are on an Apple silicon system, "
" make sure you are building for arm64.")
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
@@ -75,7 +75,7 @@ elseif (MLX_BUILD_METAL)
COMMAND_ERROR_IS_FATAL ANY)
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
@@ -123,27 +123,16 @@ else()
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib" ${BLAS_LIBRARIES})
message(STATUS "Blas incclude" ${BLAS_INCLUDE_DIRS})
message(STATUS ${BLAS_LIBRARIES})
message(STATUS ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include)
message(STATUS "Lapack lib" ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
mlx
mlx
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>

View File

@@ -1,4 +1,3 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
python/mlx/py.typed # support type hinting as in PEP-561

View File

@@ -68,18 +68,10 @@ in the documentation.
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
**With `pip`**:
```
pip install mlx
```
**With `conda`**:
```
conda install -c conda-forge mlx
```
Checkout the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.

View File

@@ -166,13 +166,13 @@ if __name__ == "__main__":
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 234, 768, 3072),
(1, 64, 64, 25344),
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
(15, 1023, 1023, 1023),
(17, 1025, 1025, 1025),
)
for dtype in dtypes:

View File

@@ -60,60 +60,20 @@ def matmul(x, y):
mx.eval(ys)
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
def _quant_matmul(x, w, s, b, group_size, bits):
ys = []
for i in range(10):
ys.append(
mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
)
)
ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
mx.eval(ys)
quant_matmul = {
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
"quant_matmul_128_2": partial(
_quant_matmul, transpose=False, group_size=128, bits=2
),
"quant_matmul_128_4": partial(
_quant_matmul, transpose=False, group_size=128, bits=4
),
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
"quant_matmul_t_32_4": partial(
_quant_matmul, transpose=True, group_size=32, bits=4
),
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
"quant_matmul_t_64_4": partial(
_quant_matmul, transpose=True, group_size=64, bits=4
),
"quant_matmul_t_64_8": partial(
_quant_matmul, transpose=True, group_size=64, bits=8
),
"quant_matmul_t_128_2": partial(
_quant_matmul, transpose=True, group_size=128, bits=2
),
"quant_matmul_t_128_4": partial(
_quant_matmul, transpose=True, group_size=128, bits=4
),
"quant_matmul_t_128_8": partial(
_quant_matmul, transpose=True, group_size=128, bits=8
),
"quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
"quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
"quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
"quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
}
@@ -269,13 +229,6 @@ def linear(w, b, x):
mx.eval(ys)
def linear_fused(w, b, x):
ys = []
for i in range(10):
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
mx.eval(ys)
def rope(x):
*_, N, D = x.shape
ys = []
@@ -416,10 +369,7 @@ if __name__ == "__main__":
print(bench(quant_matmul[args.benchmark], *xs))
elif args.benchmark == "linear":
if args.fused:
print(bench(linear_fused, *xs))
else:
print(bench(linear, *xs))
print(bench(linear, *xs))
elif args.benchmark == "sum_axis":
print(bench(reduction, "sum", axis, x))

View File

@@ -0,0 +1,198 @@
# Copyright © 2023 Apple Inc.
import math
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
class RoPE(nn.Module):
dims: int
traditional: bool = False
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
else:
rx = jnp.concatenate([rx1, rx2], axis=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
return rx
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
dtype=jnp.float32,
):
D = D // 2
positions = jnp.arange(offset, N, dtype=dtype)
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
return costheta, sintheta
@nn.compact
def __call__(self, x, offset: int = 0):
shape = x.shape
x = x.reshape((-1, shape[-2], shape[-1]))
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.reshape(shape)
class LlamaAttention(nn.Module):
dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
num_heads = self.num_heads
dims = self.dims
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = jnp.concatenate([key_cache, keys], axis=2)
values = jnp.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = jax.nn.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
dims: int
mlp_dims: int
num_heads: int
dtype: jnp.dtype
def setup(self):
dims = self.dims
mlp_dims = self.mlp_dims
num_heads = self.num_heads
self.attention = LlamaAttention(dims, num_heads, dtype)
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = jax.nn.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
jax.block_until_ready((y, c))
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
dtype = jnp.float16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (1, 1, D), dtype)
cache = [
jax.random.normal(k2, [1, H, C, D // H], dtype),
jax.random.normal(k3, [1, H, C, D // H], dtype),
]
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
params = layer.init(k4, x, mask=None, cache=cache)["params"]
@jax.jit
def model_fn(x, mask, cache):
return layer.apply({"params": params}, x, mask=mask, cache=cache)
T = measure(model_fn, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,118 @@
# Copyright © 2023 Apple Inc.
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -0,0 +1,199 @@
# Copyright © 2023 Apple Inc.
import math
import time
import torch
import torch.mps
import torch.nn as nn
def sync_if_needed(x):
if x.device != torch.device("cpu"):
torch.mps.synchronize()
class RoPE(nn.Module):
def __init__(self, dims: int, traditional: bool = False):
super().__init__()
self.dims = dims
self.traditional = traditional
def _compute_rope(self, costheta, sintheta, x):
x1 = x[..., : self.dims // 2]
x2 = x[..., self.dims // 2 : self.dims]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
else:
rx = torch.cat([rx1, rx2], dim=-1)
return rx
def _compute_traditional_rope(self, costheta, sintheta, x):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * costheta - x2 * sintheta
rx2 = x1 * sintheta + x2 * costheta
if self.dims < x.shape[-1]:
raise NotImplementedError(
"RoPE doesn't implement partial traditional application"
)
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
return rx
def forward(self, x, offset: int = 0):
shape = x.shape
x = x.view(-1, shape[-2], shape[-1])
N = x.shape[1] + offset
costheta, sintheta = RoPE.create_cos_sin_theta(
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
)
rope = (
self._compute_traditional_rope if self.traditional else self._compute_rope
)
rx = rope(costheta, sintheta, x)
return rx.view(*shape)
@staticmethod
def create_cos_sin_theta(
N: int,
D: int,
offset: int = 0,
base: float = 10000,
device="cpu",
dtype=torch.float32,
):
D = D // 2
positions = torch.arange(offset, N, dtype=dtype, device=device)
freqs = torch.exp(
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
)
theta = positions.view(-1, 1) * freqs.view(1, -1)
costheta = torch.cos(theta)
sintheta = torch.sin(theta)
return costheta, sintheta
class RMSNorm(nn.Module):
def __init__(self, dims: int, epsilon: float = 1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones((dims,)))
self.epsilon = epsilon
def forward(self, x):
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
return self.gamma * x * n
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def forward(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = torch.cat([key_cache, keys], dim=2)
values = torch.cat([value_cache, values], dim=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = torch.softmax(scores, dim=-1)
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = RMSNorm(dims)
self.norm2 = RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def forward(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = torch.nn.functional.silu(a) * b
y = self.linear3(y)
x = x + y
return x, cache
@torch.no_grad()
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
start = time.time()
for i in range(5):
y, c = model(x, mask=None, cache=cache)
sync_if_needed(x)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
device = torch.device("mps")
dtype = torch.float16
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
x = torch.randn(1, 1, D).to(device).to(dtype)
cache = [
torch.randn(1, H, C, D // H).to(device).to(dtype),
torch.randn(1, H, C, D // H).to(device).to(dtype),
]
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

View File

@@ -44,13 +44,6 @@ def time_matmul():
time_fn(mx.matmul, a, b)
def time_maximum():
a = mx.random.uniform(shape=(32, 1024, 1024))
b = mx.random.uniform(shape=(32, 1024, 1024))
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)
@@ -108,7 +101,6 @@ if __name__ == "__main__":
time_add()
time_matmul()
time_maximum()
time_exp()
time_negative()
time_logsumexp()

View File

@@ -5,15 +5,13 @@
import os
import subprocess
import mlx.core as mx
# -- Project information -----------------------------------------------------
project = "MLX"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version
version = "0.0.9"
release = "0.0.9"
# -- General configuration ---------------------------------------------------

View File

@@ -929,7 +929,7 @@ We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations like :meth:`grad`!
transformations such as :meth:`grad` and :meth:`simplify`!
Scripts
-------

View File

@@ -40,7 +40,6 @@ are the CPU and GPU.
usage/unified_memory
usage/indexing
usage/saving_and_loading
usage/function_transforms
usage/numpy
usage/using_streams

View File

@@ -9,4 +9,3 @@ Linear Algebra
:toctree: _autosummary
norm
qr

View File

@@ -180,4 +180,3 @@ In detail:
nn/layers
nn/functions
nn/losses
nn/init

View File

@@ -15,10 +15,9 @@ simple functions.
gelu
gelu_approx
gelu_fast_approx
mish
prelu
relu
selu
softshrink
prelu
silu
step
selu
mish

View File

@@ -1,45 +0,0 @@
.. _init:
.. currentmodule:: mlx.nn.init
Initializers
------------
The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.
For example:
.. code:: python
import mlx.core as mx
import mlx.nn as nn
init_fn = nn.init.uniform()
# Produces a [2, 2] uniform matrix
param = init_fn(mx.zeros((2, 2)))
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
distribution, you can do:
.. code:: python
import mlx.nn as nn
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_fn)
.. autosummary::
:toctree: _autosummary
constant
normal
uniform
identity
glorot_normal
glorot_uniform
he_normal
he_uniform

View File

@@ -9,30 +9,29 @@ Layers
:toctree: _autosummary
:template: nn-module-template.rst
ALiBi
BatchNorm
Sequential
ReLU
PReLU
GELU
SiLU
Step
SELU
Mish
Embedding
Linear
QuantizedLinear
Conv1d
Conv2d
BatchNorm
LayerNorm
RMSNorm
GroupNorm
InstanceNorm
Dropout
Dropout2d
Dropout3d
Embedding
GELU
GroupNorm
InstanceNorm
LayerNorm
Linear
Mish
MultiHeadAttention
PReLU
QuantizedLinear
RMSNorm
ReLU
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softshrink
Step
Transformer
MultiHeadAttention
ALiBi
RoPE
SinusoidalPositionalEncoding

View File

@@ -10,15 +10,14 @@ Loss Functions
:template: nn-module-template.rst
binary_cross_entropy
cosine_similarity_loss
cross_entropy
gaussian_nll_loss
hinge_loss
huber_loss
kl_div_loss
l1_loss
log_cosh_loss
mse_loss
nll_loss
smooth_l1_loss
triplet_loss
triplet_loss
hinge_loss
huber_loss
log_cosh_loss
cosine_similarity_loss

View File

@@ -35,8 +35,6 @@ Operations
cos
cosh
dequantize
diag
diagonal
divide
divmod
equal
@@ -54,9 +52,6 @@ Operations
identity
inner
isnan
isposinf
isneginf
isinf
less
less_equal
linspace

View File

@@ -40,7 +40,6 @@ model's parameters and the **optimizer state**.
SGD
RMSprop
Adagrad
Adafactor
AdaDelta
Adam
AdamW

View File

@@ -33,13 +33,13 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
.. autosummary::
:toctree: _autosummary
seed
key
split
bernoulli
categorical
gumbel
key
normal
randint
seed
split
truncated_normal
uniform
truncated_normal

View File

@@ -14,3 +14,4 @@ Transforms
jvp
vjp
vmap
simplify

View File

@@ -1,188 +0,0 @@
.. _function_transforms:
Function Transforms
===================
.. currentmodule:: mlx.core
MLX uses composable function transformations for automatic differentiation and
vectorization. The key idea behind composable function transformations is that
every transformation returns a function which can be further transformed.
Here is a simple example:
.. code-block:: shell
>>> dfdx = mx.grad(mx.sin)
>>> dfdx(mx.array(mx.pi))
array(-1, dtype=float32)
>>> mx.cos(mx.array(mx.pi))
array(-1, dtype=float32)
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
.. code-block:: shell
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
>>> d2fdx2(mx.array(mx.pi / 2))
array(-1, dtype=float32)
>>> mx.sin(mx.array(mx.pi / 2))
array(1, dtype=float32)
Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. To see the complete list of function transformations check-out the
:ref:`API documentation <transforms>`. See the following sections for more
information on :ref:`automatic differentiaion <auto diff>` and
:ref:`automatic vectorization <vmap>`.
Automatic Differentiation
-------------------------
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
.. note::
If you are coming to MLX from PyTorch, you no longer need functions like
``backward``, ``zero_grad``, and ``detach``, or properties like
``requires_grad``.
The most basic example is taking the gradient of a scalar-valued function as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` function to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:
.. code-block:: python
def loss_fn(w, x, y):
return mx.mean(mx.square(w * x - y))
w = mx.array(1.0)
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to w:
grad_fn = mx.grad(loss_fn)
dloss_dw = grad_fn(w, x, y)
# Prints array(-1, dtype=float32)
print(dloss_dw)
# To get the gradient with respect to x we can do:
grad_fn = mx.grad(loss_fn, argnums=1)
dloss_dx = grad_fn(w, x, y)
# Prints array([-1, 1], dtype=float32)
print(dloss_dx)
One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:
.. code-block:: python
# Computes the gradient of loss_fn with respect to w:
loss_and_grad_fn = mx.value_and_grad(loss_fn)
loss, dloss_dw = loss_and_grad_fn(w, x, y)
# Prints array(1, dtype=float32)
print(loss)
# Prints array(-1, dtype=float32)
print(dloss_dw)
You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).
Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:
.. code-block:: python
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
x = mx.array([0.5, -0.5])
y = mx.array([1.5, -1.5])
# Computes the gradient of loss_fn with respect to both the
# weight and bias:
grad_fn = mx.grad(loss_fn)
grads = grad_fn(params, x, y)
# Prints
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
print(grads)
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
Automatic Vectorization
-----------------------
.. _vmap:
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.
.. warning::
Some operations are not yet supported with :func:`vmap`. If you encounter an error
like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
<https://github.com/ml-explore/mlx/issues>`_ and include your function.
We will prioritize including it.
A naive way to add the elements from two sets of vectors is with a loop:
.. code-block:: python
xs = mx.random.uniform(shape=(4096, 100))
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
Let's time these two different versions:
.. code-block:: python
import timeit
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster.
Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

View File

@@ -20,7 +20,7 @@ Transforming Compute Graphs
Lazy evaluation let's us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations.
:func:`vmap` and graph optimizations like :func:`simplify`.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to

View File

@@ -104,10 +104,7 @@ void axpby_impl(
}
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
const std::vector<array>& inputs,
std::vector<array>& out_arr) {
auto out = out_arr[0];
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
@@ -178,10 +175,7 @@ void axpby_impl_accelerate(
}
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outarr) {
auto out = outarr[0];
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -195,15 +189,13 @@ void Axpby::eval_cpu(
}
// Fall back to common backend if specializations are not available
eval(inputs, outarr);
eval(inputs, out);
}
#else // Accelerate not available
/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
@@ -216,11 +208,8 @@ void Axpby::eval_cpu(
#ifdef _METAL_
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outarr) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
auto out = outarr[0];
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
@@ -306,9 +295,7 @@ void Axpby::eval_gpu(
#else // Metal is not available
/** Fail evaluation on GPU */
void Axpby::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& out) {
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("Axpby has no GPU implementation.");
}
@@ -319,7 +306,7 @@ void Axpby::eval_gpu(
///////////////////////////////////////////////////////////////////////////////
/** The Jacobian-vector product. */
std::vector<array> Axpby::jvp(
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
@@ -334,33 +321,32 @@ std::vector<array> Axpby::jvp(
if (argnums.size() > 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())};
return multiply(scale_arr, tangents[0], stream());
}
// If, argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotangents[0].dtype());
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
/** Vectorize primitive along given axis */
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");

View File

@@ -42,13 +42,11 @@ class Axpby : public Primitive {
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
override;
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
std::vector<array> jvp(
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
@@ -56,9 +54,8 @@ class Axpby : public Primitive {
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const array& cotan,
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
@@ -66,7 +63,7 @@ class Axpby : public Primitive {
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<std::vector<array>, std::vector<int>> vmap(
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
@@ -83,7 +80,7 @@ class Axpby : public Primitive {
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, std::vector<array>& out);
void eval(const std::vector<array>& inputs, array& out);
};
} // namespace mlx::core

View File

@@ -1,3 +0,0 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/mlx-explore/mlx@main"]
build-backend = "setuptools.build_meta"

View File

@@ -41,6 +41,6 @@ error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
throughput = num_iters / (toc - tic)
print(
f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
f"Throughput {throughput:.5f} (it/s)"
)

View File

@@ -5,7 +5,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
@@ -20,7 +19,7 @@ target_sources(
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
target_sources(

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <functional>
@@ -47,17 +47,6 @@ array::array(
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
std::move(primitive),
std::move(inputs))) {}
std::vector<array> array::make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
@@ -93,17 +82,9 @@ array::array(
}
void array::detach() {
for (auto& s : array_desc_->siblings) {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
@@ -157,14 +138,6 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(array other) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
@@ -179,35 +152,9 @@ array::ArrayDesc::ArrayDesc(
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
std::tie(size, strides) = cum_prod(this->shape);
std::tie(size, strides) = cum_prod(shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
: arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}

View File

@@ -127,7 +127,11 @@ class array {
using value_type = const array;
using reference = value_type;
explicit ArrayIterator(const array& arr, int idx = 0);
explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
if (arr.ndim() == 0) {
throw std::invalid_argument("Cannot iterate over 0-d array.");
}
}
reference operator*() const;
@@ -172,12 +176,6 @@ class array {
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
const std::vector<Dtype>& dtypes,
@@ -221,11 +219,6 @@ class array {
return *(array_desc_->primitive);
};
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
@@ -240,11 +233,6 @@ class array {
return array_desc_->inputs;
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
@@ -267,11 +255,6 @@ class array {
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
void detach();
@@ -292,12 +275,6 @@ class array {
return array_desc_->data->buffer;
};
// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
@@ -338,8 +315,6 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_;
}
@@ -382,9 +357,6 @@ class array {
// The arrays position in the output list
uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(
@@ -392,18 +364,12 @@ class array {
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
// and a the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
};

View File

@@ -29,16 +29,12 @@ std::tuple<bool, size_t, array> check_transpose(const array& arr) {
}
}
inline void matmul_cblas_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -46,11 +42,6 @@ inline void matmul_cblas_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -59,34 +50,21 @@ inline void matmul_cblas_general(
M,
N,
K,
alpha, // alpha
1.0f, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
0.0f, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[matmul_cblas] on CPU currently only supports float32");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_cblas_general(a_pre, b_pre, out);
}
inline void matmul_bnns_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -94,16 +72,11 @@ inline void matmul_bnns_general(
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha,
/* float beta = */ beta,
/* float alpha = */ 1.0,
/* float beta = */ 0.0,
/* bool transA = */ a_transposed,
/* bool transB = */ b_transposed,
/* bool quadratic = */ false,
@@ -184,12 +157,6 @@ inline void matmul_bnns_general(
BNNSFilterDestroy(bnns_filter);
}
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
// TODO: Update to utilize BNNS broadcasting
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_bnns_general(a_pre, b_pre, out);
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -199,16 +166,4 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
return matmul_bnns(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
if (out.dtype() == float32) {
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
}
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <cmath>
@@ -35,8 +35,6 @@ DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@@ -52,8 +50,6 @@ DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
@@ -64,22 +60,30 @@ DEFAULT(Scatter)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) {
// No-op for unsigned types
out.copy_shared_buffer(in);
@@ -132,8 +136,12 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -144,8 +152,12 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -156,8 +168,12 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -168,8 +184,12 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -180,8 +200,12 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -192,8 +216,12 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -205,23 +233,30 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == float32 && out.dtype() == int32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return;
} else if (in.dtype() == uint32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
} else if (in.dtype() == int32 && out.dtype() == float32) {
set_unary_output_data(in, out);
allocfn();
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return;
}
@@ -233,8 +268,12 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -245,8 +284,12 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -334,8 +377,12 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
@@ -362,8 +409,12 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) {
case Base::e:
vvlogf(
@@ -387,8 +438,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) {
@@ -400,6 +455,47 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x > y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32) {
binary(
a,
b,
out,
[](auto x, auto y) { return (x < y) ? x : y; },
UseDefaultBinaryOp(),
UseDefaultBinaryOp(),
[](const auto* a, const auto* b, auto* out, int n) {
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -429,8 +525,13 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return -x; });
}
@@ -443,13 +544,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
int size = a.size();
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
@@ -491,8 +586,12 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -503,8 +602,12 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -515,8 +618,12 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
@@ -527,8 +634,12 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else {
@@ -583,8 +694,12 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);
@@ -595,8 +710,12 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size);
} else {
eval(inputs, out);

View File

@@ -16,5 +16,4 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
)

View File

@@ -233,33 +233,14 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x > y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
}
void Minimum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (is_floating_point(out.dtype())) {
binary(a, b, out, [](auto x, auto y) {
if (std::isnan(x)) {
return x;
}
return (x < y) ? x : y;
});
} else {
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
}
void Multiply::eval(const std::vector<array>& inputs, array& out) {

View File

@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
@@ -39,83 +40,29 @@ void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt,
bool donate_with_move = false) {
BinaryOpType bopt) {
switch (bopt) {
case ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
break;
case VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
break;
case General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}

View File

@@ -289,16 +289,11 @@ void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output
switch (ctype) {
case CopyType::Vector:
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
dst.copy_shared_buffer(src);
} else {
auto size = src.data_size();
dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
dst.set_data(
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
src.data_size(),
src.strides(),
src.flags());
break;
case CopyType::Scalar:
case CopyType::General:

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
@@ -6,8 +6,6 @@
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
@@ -47,8 +45,6 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)
@@ -92,7 +88,6 @@ DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
@@ -101,16 +96,17 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)
DEFAULT_MULTI(QRF)
namespace {
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
inline void matmul_common_general(
const array& a_pre,
const array& b_pre,
array& out,
float alpha = 1.0f,
float beta = 0.0f) {
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
@@ -128,15 +124,9 @@ inline void matmul_common_general(
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_sgemm(
CblasRowMajor,
@@ -145,41 +135,16 @@ inline void matmul_common_general(
M,
N,
K,
alpha, // alpha
1.0f, // alpha
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
0.0f, // beta
out.data<float>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[Matmul::eval_cpu] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
return matmul_common_general(inputs[0], inputs[1], out);
}
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
copy(c, out, ctype);
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
}
} // namespace mlx::core

View File

@@ -232,38 +232,22 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
break;
case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
});
break;
case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
});
@@ -280,14 +264,17 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
switch (out.dtype()) {
case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
break;
case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
});
break;
case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
});
@@ -601,58 +588,6 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@@ -1,153 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/lapack.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>
struct lpack;
template <>
struct lpack<float> {
static void xgeqrf(
const int* m,
const int* n,
float* a,
const int* lda,
float* tau,
float* work,
const int* lwork,
int* info) {
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
}
static void xorgqr(
const int* m,
const int* n,
const int* k,
float* a,
const int* lda,
const float* tau,
float* work,
const int* lwork,
int* info) {
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
}
};
template <typename T>
void qrf_impl(const array& a, array& q, array& r) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = std::max(M, N);
size_t num_matrices = a.size() / (M * N);
int num_reflectors = std::min(M, N);
auto tau =
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), float32, nullptr, {});
auto flags = in.flags();
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral);
T optimal_work;
int lwork = -1;
int info;
// Compute workspace size
lpack<T>::xgeqrf(
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
// Update workspace size
lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Solve
lpack<T>::xgeqrf(
&M,
&N,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
allocator::free(work);
r.set_data(allocator::malloc_or_wait(r.nbytes()));
copy_inplace(in, r, CopyType::General);
for (int i = 0; i < num_matrices; ++i) {
// Zero lower triangle
for (int j = 0; j < r.shape(-2); ++j) {
for (int k = 0; k < j; ++k) {
r.data<T>()[i * N * M + j * N + k] = 0;
}
}
}
// Get work size
lwork = -1;
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
nullptr,
&lda,
nullptr,
&optimal_work,
&lwork,
&info);
lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
lpack<T>::xorgqr(
&M,
&N,
&num_reflectors,
in.data<float>() + M * N * i,
&lda,
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
static_cast<T*>(work.raw_ptr()),
&lwork,
&info);
}
q.set_data(allocator::malloc_or_wait(q.nbytes()));
copy_inplace(in, q, CopyType::General);
// Cleanup
allocator::free(work);
allocator::free(tau);
}
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[QRF::eval] only supports float32.");
}
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
}
} // namespace mlx::core

View File

@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"
@@ -118,12 +119,6 @@ void _qmm_dispatch_typed(
switch (bits) {
case 2: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
@@ -140,12 +135,6 @@ void _qmm_dispatch_typed(
}
case 4: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
@@ -162,12 +151,6 @@ void _qmm_dispatch_typed(
}
case 8: {
switch (group_size) {
case 32:
if (transposed_w) {
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
}
case 64:
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);

View File

@@ -56,32 +56,23 @@ struct SignOp {
struct RoundOp {
template <typename T>
T operator()(T x) {
return std::rint(x);
return std::round(x);
}
complex64_t operator()(complex64_t x) {
return {std::rint(x.real()), std::rint(x.imag())};
return {std::round(x.real()), std::round(x.imag())};
}
};
void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
}
}
template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);

View File

@@ -23,16 +23,6 @@ void* Buffer::raw_ptr() {
namespace metal {
static bool cache_enabled_ = true;
bool cache_enabled() {
return cache_enabled_;
}
void set_cache_enabled(bool enabled) {
cache_enabled_ = enabled;
}
namespace {
BufferCache::BufferCache(MTL::Device* device)
@@ -206,11 +196,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (cache_enabled()) {
buffer_cache_.recycle_to_cache(buf);
} else {
buf->release();
}
buffer_cache_.recycle_to_cache(buf);
}
MetalAllocator& allocator() {

View File

@@ -2,6 +2,7 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <sstream>
@@ -69,7 +70,7 @@ void explicit_gemm_conv_1D_gpu(
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
return steel_matmul(
mlx_matmul(
s,
d,
/*a = */ in_strided,
@@ -261,7 +262,7 @@ void explicit_gemm_conv_2D_gpu(
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
return steel_matmul(
mlx_matmul(
s,
d,
/*a = */ in_strided,
@@ -410,7 +411,7 @@ void winograd_conv_2D_gpu(
copies_w.push_back(out_wg);
{
std::vector<array> empty_copies;
steel_matmul(
mlx_matmul(
s,
d,
/*a = */ inp_wg,

View File

@@ -12,15 +12,11 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
@@ -71,8 +67,7 @@ void copy_gpu_inplace(
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <dlfcn.h>
#include <cstdlib>
@@ -242,127 +242,37 @@ void Device::register_library(
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Function* Device::get_function_(
const std::string& name,
MTL::Library* mtl_lib) {
// Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name);
return mtl_function;
}
MTL::Function* Device::get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib) {
if (func_consts.empty() && (specialized_name == name)) {
return get_function_(name, mtl_lib);
}
// Prepare function constants
auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
for (auto [value, type, index] : func_consts) {
mtl_func_consts->setConstantValue(value, type, index);
}
// Prepare function desc
auto desc = MTL::FunctionDescriptor::functionDescriptor();
desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
desc->setSpecializedName(
NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
desc->setConstantValues(mtl_func_consts);
// Pull kernel from library
NS::Error* error = nullptr;
auto mtl_function = mtl_lib->newFunction(desc, &error);
// Throw error if unable to build metal function
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load function " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function) {
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel;
if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
}
// Throw error if unable to compile metal function
if (!mtl_function || !kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
@@ -372,170 +282,11 @@ MTL::ComputePipelineState* Device::get_kernel_(
throw std::runtime_error(msg.str());
}
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions) {
// Check inputs
if (!linked_functions) {
return get_kernel_(name, mtl_function);
}
if (!mtl_function) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
throw std::runtime_error(msg.str());
}
// Prepare compute pipeline state descriptor
auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
desc->setComputeFunction(mtl_function);
desc->setLinkedFunctions(linked_functions);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
auto kernel = device_->newComputePipelineState(
desc, MTL::PipelineOptionNone, nullptr, &error);
// Throw error if unable to compile metal function
if (!kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return kernel;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(source);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
return nullptr;
}
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
std::vector<NS::Object*> objs(funcs.size());
for (int i = 0; i < funcs.size(); i++) {
objs[i] = funcs[i];
}
NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
lfuncs->setPrivateFunctions(funcs_arr);
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
kernel_map_.insert({name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
}
Device& device(mlx::core::Device) {
static Device metal_device;
return metal_device;

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-24 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -31,9 +31,6 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
class Device {
public:
Device();
@@ -62,71 +59,14 @@ class Device {
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::Library* get_library(
const std::string& name,
const std::string& source_string,
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& hash_name = "",
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
const std::string& name,
const std::string& lib_name = "mlx");
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
MTL::Function* get_function_(
const std::string& name,
const std::string& specialized_name,
const MTLFCList& func_consts,
MTL::Library* mtl_lib);
MTL::LinkedFunctions* get_linked_functions_(
const std::vector<MTL::Function*>& funcs);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function);
MTL::ComputePipelineState* get_kernel_(
const std::string& name,
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;

View File

@@ -1,6 +1,5 @@
set(
HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
@@ -18,6 +17,7 @@ set(
"binary_two"
"conv"
"copy"
"gemm"
"gemv"
"quantized"
"random"
@@ -29,27 +29,26 @@ set(
"indexing"
)
function(build_kernel_base TARGET SRCFILE DEPS)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "gemm")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
endif()
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal -Wall -Wextra
-fno-fast-math
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS}
OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air"
-o ${KERNEL}.air
DEPENDS ${SRCFILE} ${HEADERS_PADDED}
OUTPUT ${KERNEL}.air
COMMENT "Building ${KERNEL}.air"
VERBATIM
)
endfunction(build_kernel_base)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
endif()
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
endfunction(build_kernel)
foreach(KERNEL ${KERNELS})
@@ -57,15 +56,6 @@ foreach(KERNEL ${KERNELS})
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
foreach(KERNEL ${STEEL_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib

View File

@@ -63,6 +63,18 @@ struct ArgMax {
}
};
bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
}
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(

View File

@@ -38,59 +38,49 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
T expected = mlx_atomic_load_explicit(object, offset);
while (!mlx_atomic_compare_exchange_weak_explicit(
object, &expected, val * expected, offset)) {
@@ -102,7 +92,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread T* expected,
T val,
uint offset) {
int offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
@@ -116,7 +106,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
int offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val < expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -131,7 +121,7 @@ template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
device mlx_atomic<float>* object,
float val,
uint offset) {
int offset) {
float expected = mlx_atomic_load_explicit(object, offset);
while (val > expected) {
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -158,7 +148,7 @@ union uint_or_packed {
template <typename T, typename Op>
struct mlx_atomic_update_helper {
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
uint operator()(uint_or_packed<T> init, T update, int elem_offset) {
Op op;
init.val[elem_offset] = op(update, init.val[elem_offset]);
return init.bits;
@@ -169,9 +159,9 @@ template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
device mlx_atomic<T>* object,
T update,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
mlx_atomic_update_helper<T, Op> helper;
uint_or_packed<T> expected;
@@ -252,9 +242,9 @@ struct __Min {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
uint pack_offset = offset / sizeof(T);
uint elem_offset = offset % sizeof(T);
mlx_atomic_load_explicit(device mlx_atomic<T>* object, int offset) {
int pack_offset = offset / sizeof(T);
int elem_offset = offset % sizeof(T);
uint_or_packed<T> packed_val;
packed_val.bits =
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
@@ -263,17 +253,15 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
METAL_FUNC void
mlx_atomic_fetch_and_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = __UINT32_MAX__;
identity.val[elem_offset] = val;
@@ -284,9 +272,9 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
uint pack_offset = offset / packing_size<T>;
uint elem_offset = offset % packing_size<T>;
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, int offset) {
int pack_offset = offset / packing_size<T>;
int elem_offset = offset % packing_size<T>;
uint_or_packed<T> identity;
identity.bits = 0;
identity.val[elem_offset] = val;
@@ -296,34 +284,26 @@ mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_min_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_max_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_add_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
device mlx_atomic<T>* object,
T val,
uint offset) {
METAL_FUNC void
mlx_atomic_fetch_mul_explicit(device mlx_atomic<T>* object, T val, int offset) {
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}
@@ -332,11 +312,11 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
device mlx_atomic<T>* object,
thread uint* expected,
uint val,
uint offset) {
int offset) {
return atomic_compare_exchange_weak_explicit(
&(object[offset].val),
expected,
val,
memory_order_relaxed,
memory_order_relaxed);
}
}

View File

@@ -58,9 +58,6 @@ struct LessEqual {
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
@@ -70,48 +67,20 @@ struct LogAddExp {
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
return x >= y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
return x <= y ? x : y;
}
};
@@ -420,4 +389,4 @@ instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)

View File

@@ -5,7 +5,7 @@
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv.h"
#include "mlx/backend/metal/kernels/gemm/conv.h"
using namespace metal;

View File

@@ -1,10 +1,9 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
@@ -24,26 +23,26 @@ template <typename T,
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
const constant int &M [[buffer(3)]],
const constant int &N [[buffer(4)]],
const constant int &K [[buffer(5)]],
const constant int &batch_stride_a [[buffer(6)]],
const constant int &batch_stride_b [[buffer(7)]],
const constant int &batch_stride_c [[buffer(8)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
A, B, C,
params,
As, Bs,
M, N, K,
batch_stride_a, batch_stride_b, batch_stride_c,
tgp_memory,
simd_lane_id, simd_group_id, tid, lid
);
}
@@ -53,12 +52,17 @@ template <typename T,
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant GEMMParams* params [[buffer(3)]], \
const constant int &M [[buffer(3)]], \
const constant int &N [[buffer(4)]], \
const constant int &K [[buffer(5)]], \
const constant int &batch_stride_a [[buffer(6)]], \
const constant int &batch_stride_b [[buffer(7)]], \
const constant int &batch_stride_c [[buffer(8)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
@@ -80,10 +84,10 @@ template <typename T,
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// TODO: Accumulation in different type

View File

@@ -0,0 +1,538 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BROWS,
int BCOLS,
int BK,
int vec_size,
int tgp_size,
bool transpose,
bool ldK,
int tgp_padding = 0>
struct BlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
// Leading dimension for src
const int src_ld;
// Stride along reduction axis between blocks
const int tstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tstride(
BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
// Iterate over rows of block
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
// Row is in bounds, we check against column
if ((bi + i) < src_tile_dim.y) {
// Use fast thread memory for bound checks
short tmp_idx[vec_size];
T tmp_val[vec_size];
// Make sure tmp_idx only contains valid indices
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
}
// Read all valid indices into tmp_val
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[i * src_ld + tmp_idx[j]];
}
// Zero out unneeded values
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
// Row is out of bounds, we just fill tgp memory with zeros
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tstride;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC BlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t = BlockLoader<
T,
BM,
BK,
BK,
vec_size,
tgp_size,
transpose_a,
true,
tgp_padding_a>;
using loader_b_t = BlockLoader<
T,
BK,
BN,
BK,
vec_size,
tgp_size,
transpose_b,
false,
tgp_padding_b>;
using mma_t = BlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant int& M [[buffer(3)]],
const constant int& N [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& batch_stride_a [[buffer(6)]],
const constant int& batch_stride_b [[buffer(7)]],
const constant int& batch_stride_c [[buffer(8)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
// Adjust for batch
A += batch_stride_a * tid.z;
B += batch_stride_b * tid.z;
C += batch_stride_c * tid.z;
// Adjust for transpose
const int lda_dev = transpose_a ? M : K;
const int ldb_dev = transpose_b ? K : N;
// Find block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
A += transpose_a ? c_row : c_row * K;
B += transpose_b ? c_col * K : c_col;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
mma_t mma_op(simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned && K_aligned) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN aligned, K unaligned loop
else if (MN_aligned && !K_aligned) {
// Main loop
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
// Loop tail
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
// Store results to device memory
mma_op.store_result(C, N);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MNK unaligned loop
else { // Loop over K - unaligned case
short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, BM));
loader_b.load_safe(short2(BN, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
mma_op.store_result(C, N);
return;
} else {
int k = 0;
for (; k + BK <= K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_safe(short2(BK, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, BK));
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
if (k < K) {
loader_a.load_safe(short2(K - k, src_tile_dims.y));
loader_b.load_safe(short2(src_tile_dims.x, K - k));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
threadgroup_barrier(mem_flags::mem_none);
mma_op.store_result_safe(C, N, src_tile_dims);
return;
}
}
}
};

View File

@@ -121,18 +121,8 @@ struct GEMVKernel {
for(int tm = 0; tm < TM; tm++) {
// Load for the row
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * in_vec_size + col_idx];
}
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
}
// Accumulate results

View File

@@ -173,7 +173,8 @@ template <typename T, typename IdxT, typename Op, int NIDX>
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \

View File

@@ -5,10 +5,9 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
@@ -142,11 +141,10 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_g = out_vec_size / group_size;
int out_col_start = tid.y * (BN * el_per_int);
int out_col = out_col_start + simd_gid * el_per_int;
int out_col = (tid.y * BN + simd_gid) * el_per_int;
w += out_col / el_per_int;
scales += out_col_start / group_size;
biases += out_col_start / group_size;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
@@ -156,22 +154,23 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
int offset_lid = simd_lid + i;
int offset_gid = simd_gid + i;
bool thread_in_bounds = offset_lid < in_vec_size;
bool group_in_bounds = offset_gid < in_vec_size;
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
x_block[simd_lid] = x[simd_lid + i];
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lid < groups_per_block && group_in_bounds) {
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -181,7 +180,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
w_local = w[(i + simd_lid) * out_vec_size_w];
// Do all the work.
#pragma clang loop unroll(full)
@@ -207,7 +206,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
}
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -237,9 +236,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
@@ -259,7 +257,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
@@ -295,48 +292,21 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
if (!aligned_N && num_outs < BN) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
if (y_col + offset_col < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
} else {
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = 0;
}
}
}
} else {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
}
}
@@ -354,8 +324,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
@@ -391,8 +361,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK, BN>;
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, false>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
threadgroup T scales_block[BK * groups_per_block];
threadgroup T biases_block[BK * groups_per_block];
@@ -447,48 +417,21 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
if (k + BK >= K) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
if (y_row + offset_row < K) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
} else {
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = 0;
}
}
}
} else {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
}
}
@@ -540,9 +483,6 @@ instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
instantiate_qmv_types( 32, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -570,13 +510,10 @@ instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
#define instantiate_qmm_t(name, itype, group_size, bits) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
@@ -591,12 +528,9 @@ instantiate_qvm_types( 32, 8)
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float32, float, group_size, bits, false) \
instantiate_qmm_t(float16, half, group_size, bits, false) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
instantiate_qmm_t(float32, float, group_size, bits, true) \
instantiate_qmm_t(float16, half, group_size, bits, true) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
instantiate_qmm_t(float32, float, group_size, bits) \
instantiate_qmm_t(float16, half, group_size, bits) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
@@ -604,9 +538,6 @@ instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
@@ -635,6 +566,3 @@ instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)

View File

@@ -16,7 +16,7 @@ union bool4_or_uint {
struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};
@@ -41,7 +41,7 @@ struct And {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -68,8 +68,8 @@ struct Or {
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
uint elem_idx,
uint offset = 0) {
int elem_idx,
int offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
@@ -78,7 +78,7 @@ struct Or {
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, uint offset = 0) {
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
@@ -105,7 +105,7 @@ struct Sum {
static constexpr constant U init = U(0);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}
@@ -125,7 +125,7 @@ struct Prod {
static constexpr constant U init = U(1);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}
@@ -145,7 +145,7 @@ struct Min {
static constexpr constant U init = Limits<U>::max;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}
@@ -165,7 +165,7 @@ struct Max {
static constexpr constant U init = Limits<U>::min;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, uint offset = 0) {
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}

View File

@@ -24,59 +24,11 @@ template <typename T, typename Op>
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
const device T *in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
@@ -88,18 +40,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// NB: this kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
Op op;
threadgroup U local_vals[simd_size];
U total_val = Op::init;
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
@@ -111,46 +98,6 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
@@ -164,80 +111,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
const device T *in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
@@ -255,9 +133,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid.x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize.x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
total_val = op.simd_reduce(total_val);
@@ -279,53 +194,6 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// Write row reduce output for threadgroup with 1st thread in thread group
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
@@ -343,59 +211,52 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
const device T *in,
threadgroup U *local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
inline void _contiguous_strided_reduce(
const device T *in,
device mlx_atomic<U> *out,
threadgroup U *local_data,
uint in_idx,
uint out_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
T local_vals[N_READS];
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
for(uint r = 0; r < N_READS; r++) {
uint offset = base_offset + r;
offset = offset < reduction_size ? offset : reduction_size - 1;
local_vals[r] = in[in_idx + offset * reduction_stride];
}
local_data[lsize.y * lid.x + lid.y] = total_val;
U total_val = Op::init;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
total_val = op(static_cast<U>(total_val), local_vals[r]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if(lid.y == 0) {
// Perform reduction across columns in thread group
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
}
U val = op.init;
return val;
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
op.atomic_update(out, val, out_idx);
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
@@ -404,13 +265,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
@@ -420,66 +281,18 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
ndim
);
Op op;
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
_contiguous_strided_reduce<T, U, Op, N_READS>(
in,
out,
local_data,
in_idx,
out_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
}
}
@@ -499,23 +312,6 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
@@ -526,15 +322,6 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
@@ -566,9 +353,6 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
@@ -578,9 +362,6 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
@@ -600,9 +381,6 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
@@ -612,8 +390,5 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)

View File

@@ -1,312 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
if (!M_aligned) {
short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_a.set_mask(tile_dims_A, mask_A);
}
if (!N_aligned) {
short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
loader_b.set_mask(tile_dims_B, mask_B);
}
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(mask_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(mask_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.set_mask(tile_dims_A_last, mask_A);
loader_b.set_mask(tile_dims_B_last, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
loader_a.set_mask(tile_dims_A, mask_A);
loader_b.set_mask(tile_dims_B, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(C, params->ldc);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(C, params->ldc);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,260 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
const constant GEMMAddMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel =
GEMMKernel<T, T, BM, BN, BK, WM, WN,
transpose_a, transpose_b,
MN_aligned, K_aligned,
AccumType, Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col * params->fdc;
D += c_row * params->ldd + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(params->alpha, params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
loader_a.set_mask(tile_dims_A, mask_A);
loader_b.set_mask(tile_dims_B, mask_B);
loader_a.load_safe(mask_A);
loader_b.load_safe(mask_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMAddMMParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -1,280 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device U *C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
if(!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iter_remaining,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device otype *C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
template <typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
D[0] = Epilogue::apply(out);
}
template <typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT *C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
[[kernel]] void gemm_splitk_accum<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype *C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);

View File

@@ -1,160 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void set_mask(
thread const short2& src_tile_dims,
thread bool mask[n_rows][vec_size]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
mask[i][j] =
((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
simdgroup_barrier(mem_flags::mem_none);
// Use fast thread memory for bound checks
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
}
simdgroup_barrier(mem_flags::mem_none);
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
}
simdgroup_barrier(mem_flags::mem_none);
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,264 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,79 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct GEMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int swizzle_log;
const int gemm_k_iterations_aligned;
};
struct GEMMSpiltKParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
const int split_k_partitions;
const int split_k_partition_stride;
const int split_k_partition_size;
const int gemm_k_iterations_aligned;
};
struct GEMMAddMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int ldd;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
const float alpha;
const float beta;
const int fdc;
};
} // namespace steel
} // namespace mlx

View File

@@ -1,63 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx

View File

@@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/params.h"

View File

@@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/host.h"
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@@ -134,8 +134,8 @@ struct Negative {
};
struct Round {
template <typename T> T operator()(T x) { return metal::rint(x); };
template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
template <typename T> T operator()(T x) { return metal::round(x); };
template <> complex64_t operator()(complex64_t x) { return {metal::round(x.real), metal::round(x.imag)}; };
};
struct Sigmoid {

View File

@@ -235,42 +235,12 @@ inline size_t ceildiv(size_t N, size_t M) {
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
float xp1 = 1.0f + x;
if (xp1 == Limits<float>::max) {
return Limits<float>::max;
}
if (xp1 == 1.0f) {
return x;
}
return x * (metal::log(xp1) / (xp1 - 1.0f));
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
}
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
if (xp1 == Limits<float>::max) {
return Limits<bfloat16_t>::max;
}
if (xp1 == 1.0f) {
return x;
}
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
bfloat16_t ret =
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
return ret;
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -8,7 +8,6 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/host.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
@@ -17,10 +16,6 @@
namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// MPS Matmul fallback
///////////////////////////////////////////////////////////////////////////////
namespace {
bool use_mps() {
@@ -51,9 +46,7 @@ inline void mps_matmul(
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
float alpha = 1.0f,
float beta = 0.0f) {
std::vector<array>& copies) {
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
if (out.dtype() == float16) {
@@ -128,7 +121,7 @@ inline void mps_matmul(
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
auto command_buffer = d.get_command_buffer(s.index);
kernel->setBatchSize(batch_size_out);
@@ -169,7 +162,7 @@ inline void mps_matmul(
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
auto command_buffer = d.get_command_buffer(s.index);
for (int i = 0; i < batch_size_out; ++i) {
@@ -193,11 +186,7 @@ inline void mps_matmul(
} // namespace
///////////////////////////////////////////////////////////////////////////////
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////
void steel_matmul(
void mlx_matmul(
const Stream& s,
metal::Device& d,
const array& a,
@@ -212,15 +201,6 @@ void steel_matmul(
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
using namespace mlx::steel;
// Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N)
if (batch_size_out > 1 && !transpose_a &&
a.data_size() == batch_size_out * M * K && b.size() == K * N) {
M = M * batch_size_out;
batch_size_out = 1;
}
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
@@ -229,108 +209,11 @@ void steel_matmul(
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
int _tm = M / 16;
int _tn = N / 16;
int _tk = K / 16;
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
int bm = M < 40 ? 16 : 32;
int bn = N < 40 ? 16 : 32;
int bk = 16;
int wm = 2, wn = 2;
int split_k_partitions =
_tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
int split_k_partition_stride = M * N;
int gemm_k_iterations = (K / bk) / split_k_partitions;
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
copies.push_back(C_split);
std::ostringstream kname;
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{
M,
N,
K,
lda,
ldb,
N,
tn,
tm,
split_k_partitions,
split_k_partition_stride,
split_k_partition_size,
gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
auto c_split_buf =
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
const class MTL::Resource* const resources[1] = {c_split_buf};
compute_encoder->memoryBarrier(resources, 1);
auto kernel = d.get_kernel(
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split));
compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel
set_array_buffer(compute_encoder, C_split, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if ((size_t)batch_size_out * M * N >= 2ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
@@ -341,12 +224,10 @@ void steel_matmul(
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
<< "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
@@ -355,55 +236,34 @@ void steel_matmul(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
M,
N,
K,
lda,
ldb,
N,
tn,
tm,
matrix_stride_a,
matrix_stride_b,
matrix_stride_out,
swizzle_log,
(K / bk)};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch only 1 kernel in the case of simple batching / broadcasting
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
(batch_size_a == batch_size_b ||
std::min(batch_size_a, batch_size_b) == 1)) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims =
MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
compute_encoder->setBytes(&M, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder->setBytes(&K, sizeof(int), 5);
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else { // Otherwise launch kernels with set offsets
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
} else { // Other launch kernels with set offsets
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
@@ -412,8 +272,13 @@ void steel_matmul(
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
compute_encoder->setBytes(&M, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
compute_encoder->setBytes(&K, sizeof(int), 5);
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
@@ -435,9 +300,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
@@ -466,9 +328,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto batch_size_out = out.size() / (M * N);
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
@@ -574,13 +433,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
/////////////////////////////////////////////////////////////////////////////
// Gemm specialization
d.end_encoding(s.index);
if (use_mps()) {
d.end_encoding(s.index);
return mps_matmul(
mps_matmul(
s,
d,
a,
@@ -595,9 +451,10 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_transposed,
b_transposed,
copies);
return;
}
return steel_matmul(
mlx_matmul(
s,
d,
a,
@@ -614,266 +471,4 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
copies);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& c_pre = inputs[2];
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
auto batch_size_out = out.size() / (M * N);
array c = c_pre;
int ldc = c.strides()[c.ndim() - 2];
int fdc = c.strides()[c.ndim() - 1];
int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];
int lda = a_cols;
int ldb = b_cols;
using namespace mlx::steel;
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N;
int _tm = M / 16;
int _tn = N / 16;
int _tk = K / 16;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
int bm = M < 40 ? 16 : 32;
int bn = N < 40 ? 16 : 32;
int bk = 16;
int wm = 2, wn = 2;
int split_k_partitions =
_tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
int split_k_partition_stride = M * N;
int gemm_k_iterations = (K / bk) / split_k_partitions;
int split_k_partition_size = gemm_k_iterations * bk;
array C_split({split_k_partitions, M, N}, float32, nullptr, {});
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
copies.push_back(C_split);
std::ostringstream kname;
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{
M,
N,
K,
lda,
ldb,
N,
tn,
tm,
split_k_partitions,
split_k_partition_stride,
split_k_partition_size,
gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
auto kernel = d.get_kernel(
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split) + "_axpby");
compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel
set_array_buffer(compute_encoder, C_split, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
compute_encoder->setBytes(&N, sizeof(int), 4);
set_array_buffer(compute_encoder, c, 5);
compute_encoder->setBytes(&ldc, sizeof(int), 6);
compute_encoder->setBytes(&fdc, sizeof(int), 7);
compute_encoder->setBytes(&alpha_, sizeof(float), 8);
compute_encoder->setBytes(&beta_, sizeof(float), 9);
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Regular addmm dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
std::ostringstream kname;
kname << "steel_addmm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
GEMMAddMMParams params{
M,
N,
K,
lda,
ldb,
ldc,
N,
tn,
tm,
matrix_stride_a,
matrix_stride_b,
matrix_stride_c,
matrix_stride_out,
swizzle_log,
(K / bk),
alpha_,
beta_,
fdc};
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch only 1 kernel in the case of simple batching / broadcasting
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
(batch_size_a == batch_size_b ||
std::min(batch_size_a, batch_size_b) == 1)) {
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else { // Otherwise launch kernels with set offsets
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
}
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
} // namespace mlx::core

View File

@@ -12,7 +12,7 @@
namespace mlx::core {
void steel_matmul(
void mlx_matmul(
const Stream& s,
metal::Device& d,
const array& a,

View File

@@ -42,15 +42,6 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) {
return command_buffer;
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
}
}
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
@@ -64,32 +55,27 @@ std::function<void()> make_task(
auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs);
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.push_back(s.data_shared_ptr());
}
if (!arr.is_tracer()) {
arr.detach();
}
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers), p = std::move(p)](
MTL::CommandBuffer* cbuf) {
[s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable {
if (!arr.is_tracer()) {
arr.detach();
for (auto s : arr.siblings()) {
s.detach();
}
}
p->set_value();
scheduler::notify_task_completion(s);
check_error(cbuf);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
[s, arr](MTL::CommandBuffer*) mutable {
if (!arr.is_tracer()) {
arr.detach();
}
});
}
};

View File

@@ -19,9 +19,6 @@ constexpr bool is_available() {
#endif
}
bool cache_enabled(void);
void set_cache_enabled(bool enabled);
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -27,8 +27,8 @@ void binary_op(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
auto& out = outputs[0];
if (out.size() == 0) {
@@ -60,7 +60,7 @@ void binary_op(
break;
}
kname << op << type_to_name(a);
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
@@ -69,14 +69,8 @@ void binary_op(
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
set_array_buffer(
compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, outputs[0], 2);
set_array_buffer(compute_encoder, outputs[1], 3);
@@ -128,7 +122,7 @@ void binary_op(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
set_binary_op_output_data(a, b, out, bopt);
if (out.size() == 0) {
return;
}
@@ -158,7 +152,7 @@ void binary_op(
break;
}
kname << op << type_to_name(a);
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
@@ -167,10 +161,8 @@ void binary_op(
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_a ? out : a, 0);
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
if (bopt == General) {
@@ -220,15 +212,11 @@ void unary_op(
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
@@ -252,8 +240,7 @@ void unary_op(
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
@@ -486,18 +473,6 @@ void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
}
void CustomVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
@@ -752,12 +727,6 @@ void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}
@@ -794,10 +763,4 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void QRF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
@@ -52,7 +52,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = std::min(32, O);
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
@@ -72,7 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
else {
std::ostringstream kname;
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
<< bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
@@ -110,10 +110,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = std::min(32, O);
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);

View File

@@ -28,40 +28,35 @@ inline auto safe_divup(size_t n, size_t m) {
return safe_div(n, m) * m;
}
inline bool is_64b_int(Dtype dtype) {
return dtype == int64 || dtype == uint64;
}
// All Reduce
void all_reduce_dispatch(
const array& in,
array& out,
const std::string& op_name,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("all_reduce_" + op_name + type_to_name(in));
metal::Device& d) {
// Get kernel and encode buffers
size_t in_size = in.size();
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
// at least n_reads inputs
int n_reads = REDUCE_N_READS;
size_t in_size = in.size();
// mod_in_size gives us the groups of n_reads needed to go over the entire
// input
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse threads groups
uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
@@ -71,52 +66,7 @@ void all_reduce_dispatch(
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
// Encode buffers and dispatch
if (is_out_64b_int == false || n_thread_groups == 1) {
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
size_t intermediate_size = n_thread_groups;
array intermediate =
array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// First dispatch
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Second pass to reduce intermediate reduction results written to DRAM
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
thread_group_size =
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
// If the number of thread groups needed exceeds 1024, we reuse threads
// groups
nthreads = thread_group_size;
group_dims = MTL::Size(thread_group_size, 1, 1);
grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void row_reduce_general_dispatch(
@@ -126,31 +76,22 @@ void row_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"row_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
metal::Device& d) {
auto kernel =
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = plan.shape.back();
size_t out_size = out.size();
auto shape = plan.shape;
auto strides = plan.strides;
shape.pop_back();
strides.pop_back();
size_t non_row_reductions = 1;
for (auto s : shape) {
non_row_reductions *= static_cast<size_t>(s);
}
size_t out_size = out.size();
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
@@ -160,6 +101,16 @@ void row_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
@@ -176,88 +127,7 @@ void row_reduce_general_dispatch(
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
if (is_out_64b_int == false || non_row_reductions == 1) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Allocate intermediate array to store partial reduction results
array intermediate = array(
{static_cast<int>(out.size()), static_cast<int>(non_row_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
reduction_size = non_row_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they'll be different post partial reduction in
// first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
// Set the arguments for the kernel
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_general_dispatch(
@@ -267,16 +137,9 @@ void strided_reduce_general_dispatch(
const ReductionPlan& plan,
const std::vector<int>& axes,
MTL::ComputeCommandEncoder* compute_encoder,
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
metal::Device& d) {
auto kernel =
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
@@ -299,7 +162,19 @@ void strided_reduce_general_dispatch(
}
int ndim = shape.size();
// Set the arguments for the kernel
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
uint n_threads_per_output =
@@ -308,22 +183,14 @@ void strided_reduce_general_dispatch(
// We spread outputs over the x dimension and inputs over the y dimension
// Threads with the same lid.x in a given threadgroup work on the same
// output and each thread in the y dimension accumulates for that output
// Threads with same lid.x, i.e. each column of threads work on same output
uint threadgroup_dim_x = std::min(out_size, 128ul);
// Number of threads along y, is dependent on number of reductions needed.
uint threadgroup_dim_y =
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
// Derive number of thread groups along x, based on how many threads we need
// along x
uint n_threadgroups_x =
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
// Derive number of thread groups along y based on how many threads we need
// along y
uint n_threadgroups_y =
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
@@ -332,122 +199,18 @@ void strided_reduce_general_dispatch(
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
if (is_out_64b_int == false) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else {
// Allocate intermediate array to store reduction results from all thread
// groups
array intermediate = array(
{static_cast<int>(out.size()),
static_cast<int>(n_threadgroups_y * non_col_reductions)},
out_dtype,
nullptr,
{});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
std::vector<array> intermediates = {intermediate};
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// We set shared memory to be exploited here for reductions within a
// threadgroup - each thread must be able to update its accumulated output
// Note: Each threadgroup should have 32kB of data in threadgroup memory
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
// This should be fine for floats, but we might need to revisit
// if we ever come to doubles. In that case, we should also cut
// down the number of threads we launch in a threadgroup
compute_encoder->setThreadgroupMemoryLength(
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Perform second pass of reductions
// Reduce results of threadgroups along y, z from first pass, that
// collectively work on each output element.
reduction_size = n_threadgroups_y * non_col_reductions;
out_size = 1;
// Shape of axes that aren't participating in reduction remains unchanged.
std::vector<int> new_shape = rem_shape;
// Update their strides since they'll be different after a partial reduction
// post first compute dispatch.
std::vector<size_t> new_strides = rem_strides;
new_strides.back() = reduction_size;
for (int i = new_shape.size() - 2; i >= 0; i--) {
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
}
ndim = new_shape.size();
auto row_reduce_kernel = d.get_kernel(
"row_reduce_general_no_atomics_" + op_name +
type_to_name(intermediate));
compute_encoder->setComputePipelineState(row_reduce_kernel);
set_array_buffer(compute_encoder, intermediate, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
// Each thread group is responsible for 1 output
size_t n_reads = REDUCE_N_READS;
size_t thread_group_size =
row_reduce_kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = row_reduce_kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
uint n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[intermediates](MTL::CommandBuffer*) mutable {
intermediates.clear();
});
}
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
} // namespace
@@ -460,6 +223,14 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
array in = inputs[0];
// TODO: Allow specific row and column reductions with types disabled
// due to atomics ?
if (size_of(in.dtype()) == 8) {
std::ostringstream msg;
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
throw std::runtime_error(msg.str());
}
// Make sure no identity reductions trickle down here
assert(!axes_.empty());
@@ -526,7 +297,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
}
// At least the last dimension is row contiguous and we are reducing over
@@ -534,7 +305,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
row_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
in, out, op_name, plan, axes_, compute_encoder, d);
}
// At least the last two dimensions are contiguous and we are doing a
@@ -543,7 +314,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
plan.type == ContiguousStridedReduce ||
plan.type == GeneralStridedReduce) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
in, out, op_name, plan, axes_, compute_encoder, d);
}
if (!copies.empty()) {

View File

@@ -19,10 +19,4 @@ std::function<void()> make_task(
"[metal::make_task] Cannot make GPU task without metal backend");
}
// No cache for CPU only
bool cache_enabled(void) {
return false;
}
void set_cache_enabled(bool) {}
} // namespace mlx::core::metal

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include "mlx/primitives.h"
@@ -17,7 +17,6 @@ namespace mlx::core {
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(AddMM)
NO_GPU(Arange)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
@@ -37,8 +36,6 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU(Remainder)
NO_GPU(Equal)
@@ -83,7 +80,6 @@ NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
@@ -92,5 +88,5 @@ NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU_MULTI(DivMod)
NO_GPU_MULTI(QRF)
} // namespace mlx::core

View File

@@ -1,440 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "mlx/allocator.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace mlx::core {
namespace detail {
bool& compiler_disabled() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return true;
} else {
return false;
}
};
static bool compiler_disabled_ = get_val();
return compiler_disabled_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
template <typename T, typename... U>
size_t getAddress(std::function<T(U...)> f) {
typedef T(fnType)(U...);
fnType** fnPointer = f.template target<fnType*>();
if (fnPointer == nullptr) {
throw std::invalid_argument(
"[compile] Cannot compile a non-addressable function.");
}
return (size_t)*fnPointer;
}
struct CompilerCache {
struct CacheEntry {
std::vector<array> inputs;
std::vector<array> outputs;
std::vector<array> tape;
bool empty{true};
};
// Returns a reference to a CacheEntry which can be updated
// by the caller to avoid copying large tapes / inputs / outputs
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
// Try to find the entry
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
auto& entries = entry_it->second;
auto is_match = [](const std::vector<array>& in1,
const std::vector<array>& in2) {
if (in1.size() != in2.size()) {
throw std::runtime_error(
"[compiler] Got different number of inputs to function,"
" this should never happen.");
}
for (int i = 0; i < in1.size(); ++i) {
if (in1[i].shape() != in2[i].shape()) {
return false;
}
if (in1[i].dtype() != in2[i].dtype()) {
return false;
}
}
return true;
};
// Loop over entries and check inputs match i.e. shapes and types must be
// equal. Note this could get really slow if one compiles the same
// function with many different shapes. May want to store entries in a
// more easily searchable structure.
for (auto& entry : entries) {
// Check the inputs match and return if so
if (is_match(inputs, entry.inputs)) {
return entry;
}
}
// Otherwise append a new cache entry
entries.push_back(CacheEntry{});
return entries.back();
};
void erase(size_t fun_id) {
cache_.erase(fun_id);
}
private:
CompilerCache() {
// Make sure the allocator is fully
// initialized before the compiler cache
allocator::allocator();
}
friend CompilerCache& compiler_cache();
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
};
CompilerCache& compiler_cache() {
static CompilerCache compiler_cache_;
return compiler_cache_;
}
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs) {
// Set the global tracing flag.
detail::InTracing in_tracing;
// Run the function on placeholder inputs
// to get compute graph
std::vector<array> tracer_inputs;
for (int i = 0; i < inputs.size(); ++i) {
array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});
in.set_tracer(true);
tracer_inputs.push_back(std::move(in));
}
return {tracer_inputs, fun(tracer_inputs)};
}
// Traverses the graph to build a tape and a map of array ids to their parents
std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& inputs,
const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::vector<array> tape;
std::unordered_set<std::uintptr_t> input_set;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
for (int i = 0; i < inputs.size(); ++i) {
auto in = inputs[i];
input_set.insert(in.id());
}
// DFS the graph to build the tape, and log parents and scalars
std::unordered_set<std::uintptr_t> cache;
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (int i = 0; i < a.inputs().size(); i++) {
auto& in = a.inputs()[i];
parents_map[in.id()].push_back({a, i});
for (auto& s : a.siblings()) {
parents_map[in.id()].push_back({s, i});
}
// Don't recurse on inputs (but add them to the tape for the purpose
// of future optimizations)
if (input_set.find(a.id()) == input_set.end()) {
recurse(in);
}
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
tape.push_back(a);
};
for (auto& a : outputs) {
recurse(a);
}
return {tape, parents_map};
}
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
int passes) {
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
auto is_scalar = [](const array& a) {
return a.is_evaled() && a.ndim() == 0;
};
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;
int dtype;
switch (a.dtype().size) {
case 1:
v = *a.data<uint8_t>();
break;
case 4:
v = *a.data<uint32_t>();
break;
case 8:
v = *a.data<uint64_t>();
break;
}
return std::make_pair(v, a.dtype().val);
};
for (auto& a : tape) {
if (is_scalar(a)) {
scalars.insert({get_scalar_rep(a), a});
}
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
return false;
}
if (a.primitive_id() == b.primitive_id()) {
return false;
}
const auto& pa = a.primitive();
const auto& pb = b.primitive();
if (typeid(pa) != typeid(pb)) {
return false;
}
if (a.inputs().size() != b.inputs().size()) {
return false;
}
for (int i = 0; i < a.inputs().size(); i++) {
if (a.inputs()[i].id() != b.inputs()[i].id()) {
return false;
}
}
return pa.is_equivalent(pb);
};
// Pass 0: fuse scalars
std::vector<array> new_tape;
for (auto& arr : tape) {
// Check if we can fuse scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
// Don't keep orphaned scalars in the tape
continue;
}
}
new_tape.push_back(std::move(arr));
}
tape = std::move(new_tape);
std::unordered_set<uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Pass 1..passes: fuse only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {
// Helper to check if we can fuse the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
std::vector<bool> mask(N, false);
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
mask[j] = true;
}
}
}
// Erase orphaned parents so we don't keep fusing with them
for (int i = N - 1; i > 0; --i) {
if (mask[i]) {
parents->second.erase(parents->second.begin() + i);
}
}
return false;
} else {
return output_set.find(a.id()) == output_set.end();
}
};
bool discard = maybe_fuse_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_fuse_parents(s);
}
// If an array and its siblings have no parents, and none of them are
// outputs, it is safe to remove it from the tape
if (!discard) {
new_tape.push_back(std::move(arr));
}
}
tape = std::move(new_tape);
}
}
std::vector<array> compile_replace(
const std::vector<array>& tape,
const std::vector<array>& trace_inputs,
const std::vector<array>& trace_outputs,
const std::vector<array>& inputs) {
std::unordered_map<uintptr_t, array> trace_to_real;
for (int i = 0; i < inputs.size(); ++i) {
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
}
for (auto& a : tape) {
// Arrays in the tape without primitives are constants
// and can be used directly
if (!a.has_primitive()) {
trace_to_real.insert({a.id(), a});
} else {
// Find real inputs
std::vector<array> real_inputs;
for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id()));
}
if (a.siblings().empty()) {
auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)});
} else {
// Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
auto trace_out = a.outputs();
for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
}
}
}
}
std::vector<array> outputs;
for (auto& o : trace_outputs) {
outputs.push_back(trace_to_real.at(o.id()));
}
return outputs;
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
size_t fun_id) {
if (compiler_disabled()) {
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);
// No matching cache entry existed, so compile
if (entry.empty) {
// Mark the entry as not empty since we are about to fill it
entry.empty = false;
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
// DFS the graph and get a tape, and a map of array id to (parent,
// position in parent inputs)
std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
std::tie(entry.tape, parents_map) =
compile_dfs(entry.inputs, entry.outputs);
// Simplify the tape
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
// This is a good point to do more optimizations, e.g. kernel fusion to
// generate new primitives. The tape needs to be updated accordingly
}
// At this point we must have a tape, now replace the placeholders
// with real arrays that can be evaluated
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
};
}
void compile_erase(size_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
if (detail::compiler_disabled()) {
return fun;
}
auto fun_id = detail::getAddress(fun);
return detail::compile(fun, fun_id);
}
void disable_compile() {
detail::compiler_disabled() = true;
}
void enable_compile() {
detail::compiler_disabled() = false;
}
} // namespace mlx::core

View File

@@ -15,8 +15,8 @@ namespace mlx::core {
struct NodeNamer {
std::unordered_map<std::uintptr_t, std::string> names;
std::string get_name(const array& x) {
auto it = names.find(x.id());
std::string get_name(uintptr_t id) {
auto it = names.find(id);
if (it == names.end()) {
// Get the next name in the sequence
// [A, B, ..., Z, AA, AB, ...]
@@ -27,11 +27,15 @@ struct NodeNamer {
var_num = (var_num - 1) / 26;
}
std::string name(letters.rbegin(), letters.rend());
names.insert({x.id(), name});
names.insert({id, name});
return name;
}
return it->second;
}
std::string get_name(const array& x) {
return get_name(x.id());
}
};
void depth_first_traversal(
@@ -120,14 +124,15 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
// Node for primitive
if (x.has_primitive()) {
os << "{ ";
os << x.primitive_id();
os << namer.get_name(x.primitive_id());
os << " [label =\"";
x.primitive().print(os);
os << "\", shape=rectangle]";
os << "; }" << std::endl;
// Arrows to primitive's inputs
for (auto& a : x.inputs()) {
os << namer.get_name(a) << " -> " << x.primitive_id() << std::endl;
os << namer.get_name(x.primitive_id()) << " -> "
<< namer.get_name(a) << std::endl;
}
}
@@ -140,7 +145,8 @@ void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
os << namer.get_name(a);
os << "; }" << std::endl;
if (x.has_primitive()) {
os << x.primitive_id() << " -> " << namer.get_name(a) << std::endl;
os << namer.get_name(a) << " -> "
<< namer.get_name(x.primitive_id()) << std::endl;
}
}
},

View File

@@ -1,55 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <variant>
#include "mlx/array.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/stream.h"
namespace mlx::core {
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a);
/** Save array to file in .npy format */
void save(const std::string& file, array a);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors(
const std::string& file,
StreamOrDevice s = {});
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>);
void save_safetensors(
const std::string& file,
std::unordered_map<std::string, array>);
using MetaData =
std::variant<std::monostate, array, std::string, std::vector<std::string>>;
/** Load array map and metadata from .gguf file format */
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s = {});
void save_gguf(
std::string file,
std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> meta_data = {});
} // namespace mlx::core

View File

@@ -4,7 +4,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
)
MESSAGE(STATUS "Downloading json")
@@ -15,11 +14,6 @@ target_include_directories(
$<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>
$<INSTALL_INTERFACE:include/json>
)
install(
DIRECTORY ${json_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/json
COMPONENT json_source
)
MESSAGE(STATUS "Downloading gguflib")
FetchContent_Declare(gguflib
@@ -32,12 +26,6 @@ target_include_directories(
$<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/gguflib>
)
install(
DIRECTORY ${gguflib_SOURCE_DIR}/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/gguflib
COMPONENT gguflib_source
)
add_library(
gguflib STATIC
${gguflib_SOURCE_DIR}/fp16.c

View File

@@ -1,16 +1,17 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cstdint>
#include <cstring>
#include <numeric>
#include <mlx/io/gguf.h>
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
extern "C" {
#include <gguflib.h>
}
namespace mlx::core {
// https://github.com/antirez/gguf-tools/blob/af7d88d808a7608a33723fba067036202910acb3/gguflib.h#L102-L108
constexpr int gguf_array_header_size = 12;
std::optional<uint32_t> dtype_to_gguf_tensor_type(const Dtype& dtype) {
switch (dtype) {
case float32:
@@ -45,15 +46,6 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
}
}
std::vector<int> get_shape(const gguf_tensor& tensor) {
std::vector<int> shape;
// The dimension order in GGML is the reverse of the order used in MLX.
for (int i = tensor.ndim - 1; i >= 0; i--) {
shape.push_back(tensor.dim[i]);
}
return shape;
}
std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
// If there's an equivalent type, we can simply copy.
@@ -78,328 +70,46 @@ std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
return {buffer, float16};
}
void set_mx_value_from_gguf(
gguf_ctx* ctx,
uint32_t type,
gguf_value* val,
MetaData& value) {
switch (type) {
case GGUF_VALUE_TYPE_UINT8:
value = array(val->uint8, uint8);
break;
case GGUF_VALUE_TYPE_INT8:
value = array(val->int8, int8);
break;
case GGUF_VALUE_TYPE_UINT16:
value = array(val->uint16, uint16);
break;
case GGUF_VALUE_TYPE_INT16:
value = array(val->int16, int16);
break;
case GGUF_VALUE_TYPE_UINT32:
value = array(val->uint32, uint32);
break;
case GGUF_VALUE_TYPE_INT32:
value = array(val->int32, int32);
break;
case GGUF_VALUE_TYPE_UINT64:
value = array(val->uint64, uint64);
break;
case GGUF_VALUE_TYPE_INT64:
value = array(val->int64, int64);
break;
case GGUF_VALUE_TYPE_FLOAT32:
value = array(val->float32, float32);
break;
case GGUF_VALUE_TYPE_BOOL:
value = array(val->boolval, bool_);
break;
case GGUF_VALUE_TYPE_STRING:
value =
std::string(val->string.string, static_cast<int>(val->string.len));
break;
case GGUF_VALUE_TYPE_FLOAT64:
value = array(val->float64, float32);
break;
case GGUF_VALUE_TYPE_ARRAY: {
ctx->off += gguf_array_header_size; // Skip header
char* data = reinterpret_cast<char*>(val) + gguf_array_header_size;
auto size = static_cast<int>(val->array.len);
if (val->array.type == GGUF_VALUE_TYPE_ARRAY) {
throw std::invalid_argument(
"[load_gguf] Only supports loading 1-layer of nested arrays.");
}
switch (val->array.type) {
case GGUF_VALUE_TYPE_UINT8:
value = array(reinterpret_cast<uint8_t*>(data), {size}, uint8);
break;
case GGUF_VALUE_TYPE_INT8:
value = array(reinterpret_cast<int8_t*>(data), {size}, int8);
break;
case GGUF_VALUE_TYPE_UINT16:
value = array(reinterpret_cast<uint16_t*>(data), {size}, uint16);
break;
case GGUF_VALUE_TYPE_INT16:
value = array(reinterpret_cast<int16_t*>(data), {size}, int16);
break;
case GGUF_VALUE_TYPE_UINT32:
value = array(reinterpret_cast<uint32_t*>(data), {size}, uint32);
break;
case GGUF_VALUE_TYPE_INT32:
value = array(reinterpret_cast<int32_t*>(data), {size}, int32);
break;
case GGUF_VALUE_TYPE_UINT64:
value = array(reinterpret_cast<uint64_t*>(data), {size}, uint64);
break;
case GGUF_VALUE_TYPE_INT64:
value = array(reinterpret_cast<uint64_t*>(data), {size}, int64);
break;
case GGUF_VALUE_TYPE_FLOAT32:
value = array(reinterpret_cast<float*>(data), {size}, float32);
break;
case GGUF_VALUE_TYPE_BOOL:
value = array(reinterpret_cast<bool*>(data), {size}, bool_);
break;
case GGUF_VALUE_TYPE_STRING: {
std::vector<std::string> strs(size);
for (auto& str : strs) {
auto str_val = reinterpret_cast<gguf_string*>(data);
data += (str_val->len + sizeof(gguf_string));
str = std::string(str_val->string, static_cast<int>(str_val->len));
ctx->off += (str_val->len + sizeof(gguf_string));
}
value = std::move(strs);
break;
}
case GGUF_VALUE_TYPE_FLOAT64:
value = array(reinterpret_cast<double*>(data), {size}, float32);
break;
default:
throw std::runtime_error(
"[load_gguf] Multiple levels of nested arrays are not supported.");
}
break;
}
default:
throw std::runtime_error("[load_gguf] Received unexpected type.");
break;
}
if (type == GGUF_VALUE_TYPE_STRING) {
ctx->off += (sizeof(gguf_string) + std::get<std::string>(value).size());
} else if (auto pv = std::get_if<array>(&value); pv) {
ctx->off += pv->nbytes();
}
}
std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
std::unordered_map<std::string, MetaData> metadata;
gguf_key key;
while (gguf_get_key(ctx, &key)) {
std::string key_name = std::string(key.name, key.namelen);
auto& val = metadata.insert({key_name, MetaData{}}).first->second;
set_mx_value_from_gguf(ctx, key.type, key.val, val);
}
return metadata;
}
std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
std::unordered_map<std::string, array> array_map;
gguf_tensor tensor;
auto check_insert = [](auto inserted) {
if (!inserted.second) {
std::ostringstream msg;
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
<< " this can happend when loading quantized tensors.";
throw std::runtime_error(msg.str());
}
};
while (gguf_get_tensor(ctx, &tensor)) {
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
tensor.type == GGUF_TYPE_Q8_0) {
gguf_load_quantized(array_map, tensor);
} else {
std::string name = std::string(tensor.name, tensor.namelen);
const auto& [data, dtype] = extract_tensor_data(&tensor);
array loaded_array = array(data, get_shape(tensor), dtype);
array_map.insert({name, loaded_array});
}
}
return array_map;
}
std::pair<
std::unordered_map<std::string, array>,
std::unordered_map<std::string, MetaData>>
load_gguf(const std::string& file, StreamOrDevice s) {
std::unordered_map<std::string, array> load_gguf(
const std::string& file,
StreamOrDevice s) {
std::unordered_map<std::string, array> result;
gguf_ctx* ctx = gguf_open(file.c_str());
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
}
auto metadata = load_metadata(ctx);
auto arrays = load_arrays(ctx);
gguf_close(ctx);
return {arrays, metadata};
}
void append_kv_array(
gguf_ctx* ctx,
const std::string& key,
array& val,
uint32_t gguf_type) {
if (val.ndim() == 1) {
size_t gguf_size = val.nbytes() + gguf_array_header_size;
std::vector<char> val_vec(gguf_size);
gguf_value* gguf_val = reinterpret_cast<gguf_value*>(val_vec.data());
gguf_val->array.type = gguf_type;
gguf_val->array.len = val.size();
memcpy(
val_vec.data() + gguf_array_header_size,
val.data<char>(),
val.nbytes());
gguf_append_kv(
ctx,
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_ARRAY,
reinterpret_cast<void*>(val_vec.data()),
gguf_size);
} else {
gguf_append_kv(
ctx,
key.c_str(),
key.length(),
gguf_type,
reinterpret_cast<void*>(val.data<char>()),
val.nbytes());
gguf_skip_key_values_section(ctx);
gguf_tensor tensor;
while (gguf_get_tensor(ctx, &tensor)) {
std::vector<int> shape;
// The dimension order in GGML is the reverse of the order used in MLX.
for (int i = tensor.ndim - 1; i >= 0; i--) {
shape.push_back(tensor.dim[i]);
}
const auto& [data, dtype] = extract_tensor_data(&tensor);
array loaded_array = array(data, shape, dtype);
std::string name = std::string(tensor.name, tensor.namelen);
result.insert({name, loaded_array});
}
gguf_close(ctx);
return result;
}
void save_gguf(
std::string file,
std::unordered_map<std::string, array> array_map,
std::unordered_map<std::string, MetaData> metadata /* = {} */) {
void save_gguf(std::string file, std::unordered_map<std::string, array> a) {
// Add .gguf to file name if it is not there
if (file.length() < 5 || file.substr(file.length() - 5, 5) != ".gguf") {
file += ".gguf";
}
gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
if (!ctx) {
throw std::runtime_error("[save_gguf] gguf_create failed");
}
auto string_to_gguf = [](char* dst, const std::string& src) {
gguf_string* val = reinterpret_cast<gguf_string*>(dst);
val->len = src.length();
memcpy(val->string, src.c_str(), src.length());
};
// Save any meta data
for (auto& [key, value] : metadata) {
if (auto pv = std::get_if<std::string>(&value); pv) {
const std::string& str = *pv;
size_t size = sizeof(gguf_string) + str.length();
std::vector<char> val_vec(size);
string_to_gguf(val_vec.data(), str);
gguf_append_kv(
ctx,
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_STRING,
static_cast<void*>(val_vec.data()),
size);
} else if (auto pv = std::get_if<std::vector<std::string>>(&value); pv) {
const auto& str_vec = *pv;
auto mem_size = std::accumulate(
str_vec.begin(), str_vec.end(), 0, [](size_t accum, const auto& s) {
return accum + s.size();
});
mem_size += str_vec.size() * sizeof(gguf_string) + gguf_array_header_size;
std::vector<char> val_vec(mem_size);
gguf_value* val = reinterpret_cast<gguf_value*>(val_vec.data());
val->array.type = GGUF_VALUE_TYPE_STRING;
val->array.len = str_vec.size();
auto str_ptr = val_vec.data() + gguf_array_header_size;
for (auto& str : str_vec) {
string_to_gguf(str_ptr, str);
str_ptr += str.length() + sizeof(gguf_string);
}
gguf_append_kv(
ctx,
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_ARRAY,
static_cast<void*>(val),
mem_size);
} else if (auto pv = std::get_if<array>(&value); pv) {
array v = *pv;
if (v.ndim() > 1) {
throw std::runtime_error(
"[save_gguf] Cannot save arrays with more than one dimension.");
}
if (v.size() == 0) {
throw std::runtime_error("[save_gguf] Cannot save empty arrays.");
}
eval(v);
if (!v.flags().row_contiguous) {
v = reshape(flatten(v), v.shape());
}
if (!v.flags().row_contiguous) {
throw std::runtime_error(
"[save_gguf] Cannot save non contiguous arrays.");
}
switch (v.dtype()) {
case float32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
break;
case int64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
break;
case int32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
break;
case int16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
break;
case int8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
break;
case uint64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
break;
case uint32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
break;
case uint16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
break;
case uint8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
break;
case bool_:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
break;
default:
std::ostringstream msg;
msg << "[save_gguf] array type " << v.dtype()
<< " not support for metadata.";
throw std::invalid_argument(msg.str());
}
} else {
throw std::runtime_error(
"[save_gguf] Received unexpected type in metadata");
}
}
// Tensor offsets are relative to data section, so we start at offset 0.
uint64_t tensor_offset = 0;
// First, append the tensor info
for (auto& [key, arr] : array_map) {
for (auto& [key, arr] : a) {
arr.eval();
// Try to make it row contiguous
@@ -444,7 +154,7 @@ void save_gguf(
}
// Then, append the tensor weights
for (const auto& [key, arr] : array_map) {
for (const auto& [key, arr] : a) {
if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
}

View File

@@ -1,20 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/io.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
extern "C" {
#include <gguflib.h>
}
namespace mlx::core {
std::vector<int> get_shape(const gguf_tensor& tensor);
void gguf_load_quantized(
std::unordered_map<std::string, array>& a,
const gguf_tensor& tensor);
} // namespace mlx::core

View File

@@ -1,158 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <cstring>
#include <mlx/io/gguf.h>
namespace mlx::core {
void unpack_32_4(uint8_t* data, int8_t* dst) {
for (int64_t j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
if (j % 2 != 0) {
x <<= 4;
}
dst[j / 2] += x;
}
// Last 16 weights are in the higher bits
for (int64_t j = 0; j < 16; ++j) {
uint8_t x = (data[j + 2] >> 4);
if (j % 2 != 0) {
x <<= 4;
}
dst[8 + j / 2] += x;
}
}
// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
const gguf_tensor& tensor,
array& weights_arr,
array& scales_arr,
array& biases_arr) {
const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
auto data = static_cast<uint8_t*>(tensor.weights_data);
auto weights = weights_arr.data<int8_t>();
auto scales = scales_arr.data<float16_t>();
auto biases = biases_arr.data<float16_t>();
for (int64_t i = 0; i < scales_arr.size(); i++) {
scales[i] = *((float16_t*)data);
biases[i] = -8 * scales[i];
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
const gguf_tensor& tensor,
array& weights_arr,
array& scales_arr,
array& biases_arr) {
const uint64_t bytes_per_block =
20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
auto data = static_cast<uint8_t*>(tensor.weights_data);
auto weights = weights_arr.data<int8_t>();
auto scales = scales_arr.data<float16_t>();
auto biases = biases_arr.data<float16_t>();
for (int64_t i = 0; i < scales_arr.size(); i++) {
scales[i] = *((float16_t*)data);
biases[i] = *((float16_t*)(data) + 1);
unpack_32_4(data, weights);
weights += 16;
data += bytes_per_block;
}
}
// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
const gguf_tensor& tensor,
array& weights_arr,
array& scales_arr,
array& biases_arr) {
const uint64_t weights_per_block = 32;
const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
auto data = static_cast<uint8_t*>(tensor.weights_data);
auto weights = weights_arr.data<int8_t>();
auto scales = scales_arr.data<float16_t>();
auto biases = biases_arr.data<float16_t>();
for (int64_t i = 0; i < scales_arr.size(); i++) {
uint8_t* block_data = data + i * bytes_per_block;
scales[i] = *((float16_t*)block_data);
biases[i] = -128 * scales[i];
for (int64_t j = 0; j < weights_per_block; ++j) {
uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
// Original data is in int8_t, so we add a bias of -128 and invert the
// first bit.
x ^= 1 << 7;
weights[i * weights_per_block + j] = x;
}
}
}
void gguf_load_quantized(
std::unordered_map<std::string, array>& a,
const gguf_tensor& tensor) {
uint64_t weights_per_byte;
if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) {
weights_per_byte = 2;
} else { // tensor.type == GGUF_TYPE_Q8_0
weights_per_byte = 1;
}
std::string name = std::string(tensor.name, tensor.namelen);
std::vector<int> shape = get_shape(tensor);
const uint64_t weights_per_block = 32;
if (shape[shape.size() - 1] % weights_per_block != 0) {
std::ostringstream msg;
msg << "[load_gguf] tensor " << name
<< "has incompatible last dim shape: " << shape[shape.size() - 1];
throw std::runtime_error(msg.str());
}
const uint64_t num_blocks = tensor.num_weights / weights_per_block;
std::vector<int> weights_shape = shape;
weights_shape.back() /= (weights_per_byte * 4);
array weights(std::move(weights_shape), uint32, nullptr, {});
weights.set_data(allocator::malloc(weights.nbytes()));
// For scales and bias
shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
array scales(shape, float16, nullptr, {});
array biases(std::move(shape), float16, nullptr, {});
scales.set_data(allocator::malloc(scales.nbytes()));
biases.set_data(allocator::malloc(biases.nbytes()));
if (tensor.type == GGUF_TYPE_Q4_0) {
extract_q4_0_data(tensor, weights, scales, biases);
} else if (tensor.type == GGUF_TYPE_Q4_1) {
extract_q4_1_data(tensor, weights, scales, biases);
} else if (tensor.type == GGUF_TYPE_Q8_0) {
extract_q8_0_data(tensor, weights, scales, biases);
}
a.insert({name, weights});
auto check_insert = [](auto inserted) {
if (!inserted.second) {
std::ostringstream msg;
msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
<< " this can happend when loading quantized tensors.";
throw std::runtime_error(msg.str());
}
};
const std::string weight_suffix = ".weight";
const std::string name_prefix =
name.substr(0, name.length() - weight_suffix.length());
check_insert(a.insert({name_prefix + ".scales", scales}));
check_insert(a.insert({name_prefix + ".biases", biases}));
}
} // namespace mlx::core

View File

@@ -3,8 +3,8 @@
#include <json.hpp>
#include <stack>
#include "mlx/io.h"
#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
using json = nlohmann::json;

View File

@@ -4,9 +4,8 @@
#include <ostream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::linalg {
@@ -173,31 +172,4 @@ array norm(
return matrix_norm(a, ord, ax, keepdims, s);
}
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::qr] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(
"[linalg::qr] Support for non-square matrices NYI.");
}
auto out = array::make_arrays(
{a.shape(), a.shape()},
{a.dtype(), a.dtype()},
std::make_unique<QRF>(to_stream(s)),
{astype(a, a.dtype(), s)});
return std::make_pair(out[0], out[1]);
}
} // namespace mlx::core::linalg

View File

@@ -60,6 +60,4 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
return norm(a, std::vector<int>{axis}, keepdims, s);
}
std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
} // namespace mlx::core::linalg

View File

@@ -6,7 +6,6 @@
#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"
#include "mlx/fft.h"
#include "mlx/io.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/random.h"

View File

@@ -1,6 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
// Copyright © 2023 Apple Inc.
#include <cmath>
#include <numeric>
#include <set>
@@ -8,7 +6,6 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -17,7 +14,8 @@ namespace {
std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
const std::vector<int>& axes,
const std::vector<int>& shape) {
const std::vector<int>& shape,
bool keepdims) {
std::set<int> axes_set;
auto ndim = shape.size();
for (auto ax : axes) {
@@ -37,7 +35,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
for (int i = 0; i < ndim; ++i) {
if (axes_set.count(i) == 0) {
out_shape.push_back(shape[i]);
} else {
} else if (keepdims) {
out_shape.push_back(1);
}
}
@@ -80,14 +78,7 @@ array arange(
msg << bool_ << " not supported for arange.";
throw std::invalid_argument(msg.str());
}
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
double real_size = std::ceil((stop - start) / step);
if (std::isnan(real_size)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
int size = std::max(static_cast<int>(real_size), 0);
int size = std::max(static_cast<int>(std::ceil((stop - start) / step)), 0);
return array(
{size},
dtype,
@@ -190,9 +181,6 @@ array full(
const array& vals,
Dtype dtype,
StreamOrDevice s /* = {} */) {
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
}
@@ -228,22 +216,22 @@ array ones_like(const array& a, StreamOrDevice s /* = {} */) {
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
if (n <= 0 || m <= 0) {
throw std::invalid_argument("[eye] N and M must be positive integers.");
throw std::invalid_argument("N and M must be positive integers.");
}
array result = zeros({n, m}, dtype, s);
array result = zeros({n * m}, dtype, s);
if (k >= m || -k >= n) {
return result;
return reshape(result, {n, m}, s);
}
int diagonal_length = k >= 0 ? std::min(n, m - k) : std::min(n + k, m);
int start_index = (k >= 0) ? k : -k * m;
std::vector<array> indices;
auto s1 = std::max(0, -k);
auto s2 = std::max(0, k);
indices.push_back(arange(s1, diagonal_length + s1, int32, s));
indices.push_back(arange(s2, diagonal_length + s2, int32, s));
array ones_array = ones({diagonal_length, 1, 1}, dtype, s);
return scatter(result, indices, ones_array, {0, 1}, s);
array diag_indices_array = arange(
start_index, start_index + diagonal_length * (m + 1), m + 1, int32, s);
array ones_array = ones({diagonal_length, 1}, dtype, s);
result = scatter(result, diag_indices_array, ones_array, 0, s);
return reshape(result, {n, m}, s);
}
array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
@@ -256,7 +244,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
return astype(greater_equal(l, r, s), type, s);
}
array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
array tril(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[tril] array must be at least 2-D");
}
@@ -264,7 +252,7 @@ array tril(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
return where(mask, x, zeros_like(x, s), s);
}
array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
array triu(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[triu] array must be at least 2-D");
}
@@ -585,29 +573,6 @@ std::vector<array> split(
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
if (indices.empty()) {
return {a};
}
if (indices.size() < 10 &&
std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
indices[0] > 0 && indices.back() < a.shape(ax)) {
std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
shapes[0][ax] = indices[0];
for (int i = 1; i < indices.size(); i++) {
shapes[i][ax] = indices[i] - indices[i - 1];
}
shapes.back()[ax] = a.shape(ax) - indices.back();
return array::make_arrays(
shapes,
dtypes,
std::make_shared<Split>(to_stream(s), indices, ax),
{a});
}
std::vector<array> res;
auto out_shape = a.shape();
auto start_indices = std::vector<int>(a.ndim(), 0);
@@ -673,27 +638,26 @@ array concatenate(
int axis,
StreamOrDevice s /* = {} */) {
if (arrays.size() == 0) {
throw std::invalid_argument(
"[concatenate] No arrays provided for concatenation");
throw std::invalid_argument("No arrays provided for concatenation");
}
// Normalize the given axis
auto ax = axis < 0 ? axis + arrays[0].ndim() : axis;
if (ax < 0 || ax >= arrays[0].ndim()) {
std::ostringstream msg;
msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate"
msg << "Invalid axis (" << axis << ") passed to concatenate"
<< " for array with shape " << arrays[0].shape() << ".";
throw std::invalid_argument(msg.str());
}
auto throw_invalid_shapes = [&]() {
std::ostringstream msg;
msg << "[concatenate] All the input array dimensions must match exactly "
<< "except for the concatenation axis. However, the provided shapes are ";
msg << "All the input array dimensions must match exactly except"
<< " for the concatenation axis. However, the provided shapes are ";
for (auto& a : arrays) {
msg << a.shape() << ", ";
}
msg << "and the concatenation axis is " << axis << ".";
msg << "and the concatenation axis is " << axis;
throw std::invalid_argument(msg.str());
};
@@ -702,13 +666,6 @@ array concatenate(
// Make the output shape and validate that all arrays have the same shape
// except for the concatenation axis.
for (auto& a : arrays) {
if (a.ndim() != shape.size()) {
std::ostringstream msg;
msg << "[concatenate] All the input arrays must have the same number of "
<< "dimensions. However, got arrays with dimensions " << shape.size()
<< " and " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < a.ndim(); i++) {
if (i == ax) {
continue;
@@ -1124,30 +1081,9 @@ array array_equal(
}
array isnan(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return not_equal(a, a, s);
}
array isinf(const array& a, StreamOrDevice s /* = {} */) {
return logical_or(isposinf(a, s), isneginf(a, s), s);
}
array isposinf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array isneginf(const array& a, StreamOrDevice s /* = {} */) {
if (is_integral(a.dtype())) {
return full(a.shape(), false, bool_, s);
}
return equal(a, array(-std::numeric_limits<float>::infinity(), a.dtype()), s);
}
array where(
const array& condition,
const array& x,
@@ -1163,43 +1099,11 @@ array allclose(
const array& b,
double rtol /* = 1e-5 */,
double atol /* = 1e-8 */,
bool equal_nan /* = false */,
StreamOrDevice s /* = {}*/) {
return all(isclose(a, b, rtol, atol, equal_nan, s), s);
}
array isclose(
const array& a,
const array& b,
double rtol /* = 1e-5 */,
double atol /* = 1e-8 */,
bool equal_nan /* = false */,
StreamOrDevice s /* = {}*/) {
// |a - b| <= atol + rtol * |b|
auto rhs = add(array(atol), multiply(array(rtol), abs(b, s), s), s);
auto lhs = abs(subtract(a, b, s), s);
auto out = less_equal(lhs, rhs, s);
// Correct the result for infinite values.
auto any_inf = logical_or(isinf(a, s), isinf(b, s), s);
auto both_inf = logical_or(
logical_and(isposinf(a, s), isposinf(b, s), s),
logical_and(isneginf(a, s), isneginf(b, s), s),
s);
// Convert all elements where either value is infinite to False.
out = logical_and(out, logical_not(any_inf, s), s);
// Convert all the elements where both values are infinite and of the same
// sign to True.
out = logical_or(out, both_inf, s);
if (equal_nan) {
auto both_nan = logical_and(isnan(a, s), isnan(b, s), s);
out = logical_or(out, both_nan, s);
}
return out;
return all(less_equal(lhs, rhs, s), s);
}
array all(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
@@ -1216,16 +1120,13 @@ array all(
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
out_shape,
bool_,
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array all(
@@ -1250,16 +1151,13 @@ array any(
if (axes.empty()) {
return astype(a, bool_, s);
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
out_shape,
bool_,
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array any(
@@ -1284,17 +1182,14 @@ array sum(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
auto out = array(
return array(
out_shape,
out_type,
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array sum(
@@ -1345,18 +1240,11 @@ array var(
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
auto nelements = compute_number_of_elements(a, axes);
auto dtype = at_least_float(a.dtype());
auto mu2 = square(mean(a, axes, keepdims, s), s);
auto a2 = mean(square(a, s), axes, keepdims, s);
auto v = subtract(a2, mu2, s);
if (ddof != 0) {
auto nelements = compute_number_of_elements(a, axes);
float factor = nelements / (nelements - ddof);
v = multiply(v, array(factor, dtype), s);
}
return v;
auto mu = mean(a, axes, true, s);
auto S = sum(square(subtract(a, mu, s), s), axes, keepdims, s);
return multiply(S, array(1.0 / (nelements - ddof), dtype), s);
}
array var(
@@ -1382,16 +1270,13 @@ array prod(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array prod(
@@ -1419,16 +1304,13 @@ array max(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array max(
@@ -1456,16 +1338,13 @@ array min(
if (axes.empty()) {
return a;
}
auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape(axes, a.shape(), keepdims);
return array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array min(
@@ -1494,17 +1373,14 @@ array argmin(
throw std::invalid_argument(
"[argmin] Cannot argmin reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape({axis}, a.shape(), keepdims);
return array(
out_shape,
uint32,
std::make_unique<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
@@ -1525,17 +1401,14 @@ array argmax(
throw std::invalid_argument(
"[argmax] Cannot argmax reduce zero size array.");
}
auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape());
auto out = array(
auto [out_shape, sorted_axes] =
compute_reduce_shape({axis}, a.shape(), keepdims);
return array(
out_shape,
uint32,
std::make_unique<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
}
return out;
}
/** Returns a sorted copy of the flattened array. */
@@ -2691,40 +2564,9 @@ inline std::vector<int> conv_out_shape(
std::vector<int> out_shape(in_shape.size());
int i = 0;
out_shape[i++] = N;
for (; i < in_shape.size() - 1; i++) {
if (pads[i - 1] < 0) {
std::ostringstream msg;
msg << "[conv] Padding sizes must be non-negative."
<< " Got padding " << pads << ".";
throw std::invalid_argument(msg.str());
}
if (strides[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Stride sizes must be positive."
<< " Got strides " << strides << ".";
throw std::invalid_argument(msg.str());
}
if (dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Dilation sizes must be positive."
<< " Got dilation " << dilation << ".";
throw std::invalid_argument(msg.str());
}
out_shape[i] = conv_out_axis_size(
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
if (out_shape[i] <= 0) {
std::ostringstream msg;
msg << "[conv] Spatial dimensions of input after padding "
<< " cannot be smaller than weight spatial dimensions."
<< " Got input with shape " << in_shape << " and padding " << pads
<< " for weight of shape " << wt_shape << ".";
throw std::invalid_argument(msg.str());
}
}
out_shape[i] = O;
@@ -2955,25 +2797,16 @@ std::tuple<array, array, array> quantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
throw std::invalid_argument(msg.str());
}
if (w.ndim() != 2) {
throw std::invalid_argument("[quantize] Only matrices supported for now");
}
if ((w.shape(1) % group_size) != 0) {
if ((w.shape(0) % 32) != 0) {
throw std::invalid_argument(
"[quantize] All dimensions should be divisible by 32 for now");
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
@@ -2988,20 +2821,6 @@ std::tuple<array, array, array> quantize(
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
shifts = reshape(shifts, {1, 1, -1}, s);
// Check that the w matrix will fill up a whole SIMD.
// This is an implementation detail which should be removed in the future but
// at least we bail out early which will result in a nice readable error.
//
// Hopefully nobody is quantizing matrices that small anyway.
if (w.shape(1) < 32 * el_per_int) {
std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, "
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
<< "shape " << w.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Compute scales and biases
array packed_w =
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
@@ -3028,20 +2847,15 @@ array dequantize(
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
throw std::invalid_argument("[dequantize] Only matrices supported for now");
}
if ((w.shape(0) % 32) != 0) {
throw std::invalid_argument(
"[dequantize] All dimensions should be divisible by 32 for now");
}
if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
@@ -3189,196 +3003,4 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
return tensordot(a, b, {{-1}, {-1}}, s);
}
/** Compute D = beta * C + alpha * (A @ B) */
array addmm(
array c,
array a,
array b,
const float& alpha /* = 1.f */,
const float& beta /* = 1.f */,
StreamOrDevice s /* = {} */) {
// Divert in the case of vector-matrix multiplication
// TODO: Add the needed specializtion
if (a.ndim() == 1 || b.ndim() == 1) {
array X = matmul(a, b, s);
array alpha_arr = array(alpha, X.dtype());
array aX = multiply(alpha_arr, X, s);
array beta_arr = array(beta, c.dtype());
array bY = multiply(beta_arr, c, s);
return add(aX, bY, s);
}
if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument(
"[addmm] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[addmm] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type({a, b, c});
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "
<< c.dtype() << ", " << a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type
<< ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
c = astype(c, out_type, s);
// We can batch the multiplication by reshaping a
if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {
std::vector<int> out_shape = a.shape();
a = reshape(a, {-1, out_shape.back()}, s);
out_shape.back() = b.shape(-1);
c = broadcast_to(c, {a.shape(0), b.shape(1)}, s);
auto out = array(
{a.shape(0), b.shape(1)},
out_type,
std::make_unique<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
return reshape(out, out_shape, s);
}
if (a.ndim() > 2 || b.ndim() > 2) {
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
// Broadcast a
inner_shape.push_back(a.shape(-2));
inner_shape.push_back(a.shape(-1));
a = broadcast_to(a, inner_shape, s);
// Broadcast b
*(inner_shape.end() - 2) = b.shape(-2);
*(inner_shape.end() - 1) = b.shape(-1);
b = broadcast_to(b, inner_shape, s);
}
auto out_shape = a.shape();
out_shape.back() = b.shape(-1);
auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape);
c = broadcast_to(c, c_broadcast_shape, s);
auto out = array(
out_shape,
out_type,
std::make_unique<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
return out;
}
array diagonal(
const array& a,
int offset /* = 0 */,
int axis1 /* = 0 */,
int axis2 /* = 1 */,
StreamOrDevice s /* = {} */
) {
int ndim = a.ndim();
if (ndim < 2) {
std::ostringstream msg;
msg << "[diagonal] Array must have at least two dimensions, but got "
<< ndim << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto ax1 = (axis1 < 0) ? axis1 + ndim : axis1;
if (ax1 < 0 || ax1 >= ndim) {
std::ostringstream msg;
msg << "[diagonal] Invalid axis1 " << axis1 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
auto ax2 = (axis2 < 0) ? axis2 + ndim : axis2;
if (ax2 < 0 || ax2 >= ndim) {
std::ostringstream msg;
msg << "[diagonal] Invalid axis2 " << axis2 << " for array with " << ndim
<< " dimensions.";
throw std::out_of_range(msg.str());
}
if (ax1 == ax2) {
throw std::invalid_argument(
"[diagonal] axis1 and axis2 cannot be the same axis");
}
auto off1 = std::max(-offset, 0);
auto off2 = std::max(offset, 0);
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
diag_size = std::max(diag_size, 0);
std::vector<array> indices = {
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
std::vector<int> slice_sizes = a.shape();
slice_sizes[ax1] = 1;
slice_sizes[ax2] = 1;
auto out = gather(a, indices, {ax1, ax2}, slice_sizes, s);
return moveaxis(squeeze(out, {ax1 + 1, ax2 + 1}, s), 0, -1, s);
}
array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
if (a.ndim() == 1) {
int a_size = a.size();
int n = a_size + std::abs(k);
auto res = zeros({n, n}, a.dtype(), s);
std::vector<array> indices;
auto s1 = std::max(0, -k);
auto s2 = std::max(0, k);
indices.push_back(arange(s1, a_size + s1, uint32, s));
indices.push_back(arange(s2, a_size + s2, uint32, s));
return scatter(res, indices, reshape(a, {a_size, 1, 1}, s), {0, 1}, s);
} else if (a.ndim() == 2) {
return diagonal(a, k, 0, 1, s);
} else {
std::ostringstream msg;
msg << "[diag] array must be 1-D or 2-D, got array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
}
std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies) {
std::vector<array> all_inputs = inputs;
all_inputs.insert(all_inputs.end(), dependencies.begin(), dependencies.end());
// Compute the stream. Maybe do it in a smarter way at some point in the
// future.
Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
: to_stream({});
// Make the output info
std::vector<std::vector<int>> shapes;
std::vector<Dtype> dtypes;
for (const auto& in : inputs) {
shapes.emplace_back(in.shape());
dtypes.emplace_back(in.dtype());
}
return array::make_arrays(
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
}
} // namespace mlx::core

View File

@@ -1,13 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
#include <optional>
#include <variant>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "array.h"
#include "device.h"
#include "io/load.h"
#include "stream.h"
namespace mlx::core {
@@ -123,8 +124,8 @@ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
return tri(n, n, 0, type, s);
}
array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});
array tril(array x, int k, StreamOrDevice s = {});
array triu(array x, int k, StreamOrDevice s = {});
/** array manipulation */
@@ -377,12 +378,6 @@ array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
array isnan(const array& a, StreamOrDevice s = {});
array isinf(const array& a, StreamOrDevice s = {});
array isposinf(const array& a, StreamOrDevice s = {});
array isneginf(const array& a, StreamOrDevice s = {});
/** Select from x or y depending on condition. */
array where(
const array& condition,
@@ -404,17 +399,6 @@ array allclose(
const array& b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {});
/** Returns a boolean array where two arrays are element-wise equal within the
* specified tolerance. */
array isclose(
const array& a,
const array& b,
double rtol = 1e-5,
double atol = 1e-8,
bool equal_nan = false,
StreamOrDevice s = {});
/**
@@ -1050,6 +1034,20 @@ array conv2d(
int groups = 1,
StreamOrDevice s = {});
/** Serialization operations */
/** Save array to out stream in .npy format */
void save(std::shared_ptr<io::Writer> out_stream, array a);
/** Save array to file in .npy format */
void save(const std::string& file, array a);
/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
const array& x,
@@ -1096,33 +1094,26 @@ array outer(const array& a, const array& b, StreamOrDevice s = {});
/** Compute the inner product of two vectors. */
array inner(const array& a, const array& b, StreamOrDevice s = {});
/** Compute D = beta * C + alpha * (A @ B) */
array addmm(
array c,
array a,
array b,
const float& alpha = 1.f,
const float& beta = 1.f,
/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
std::shared_ptr<io::Reader> in_stream,
StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors(
const std::string& file,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,
int offset = 0,
int axis1 = 0,
int axis2 = 1,
void save_safetensors(
std::shared_ptr<io::Writer> in_stream,
std::unordered_map<std::string, array>);
void save_safetensors(
const std::string& file,
std::unordered_map<std::string, array>);
/** Load array map from .gguf file format */
std::unordered_map<std::string, array> load_gguf(
const std::string& file,
StreamOrDevice s = {});
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});
/**
* Implements the identity function but allows injecting dependencies to other
* arrays. This ensures that these other arrays will have been computed
* when the outputs of this function are computed.
*/
std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies);
void save_gguf(std::string file, std::unordered_map<std::string, array> a);
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -51,31 +51,29 @@ std::tuple<array, array, int> vmap_binary_op(
} // namespace
std::vector<array> Primitive::jvp(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&) {
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::invalid_argument("Primitive's jvp not implemented.");
};
std::vector<array> Primitive::vjp(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
throw std::invalid_argument("Primitive's vjp not implemented.");
};
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>&,
const std::vector<int>&) {
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::invalid_argument("Primitive's vmap not implemented.");
};
std::vector<array> Abs::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -108,8 +106,7 @@ std::vector<array> Add::jvp(
std::vector<array> Add::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
if (argnums.size() == 1) {
return cotangents;
} else {
@@ -124,52 +121,6 @@ std::pair<std::vector<array>, std::vector<int>> Add::vmap(
return {{add(a, b, stream())}, {to_ax}};
}
std::vector<array> AddMM::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
auto& cotan = cotangents[0];
std::vector<int> reorder(cotan.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
std::iter_swap(reorder.end() - 1, reorder.end() - 2);
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
auto cotan_scaled = cotan;
if (alpha_ != 1.) {
auto alpha_arr = array(alpha_, cotan.dtype());
cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));
}
vjps.push_back(matmul(
cotan_scaled, transpose(primals[1], reorder, stream()), stream()));
} else if (arg == 1) {
// (M X K).T * M X N -> K X N
auto cotan_scaled = cotan;
if (alpha_ != 1.) {
auto alpha_arr = array(alpha_, cotan.dtype());
cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));
}
vjps.push_back(matmul(
transpose(primals[0], reorder, stream()), cotan_scaled, stream()));
} else {
auto cotan_scaled = cotan;
if (beta_ != 1.) {
auto beta_arr = array(beta_, cotan.dtype());
cotan_scaled = (multiply(beta_arr, cotan_scaled, stream()));
}
vjps.push_back(cotan_scaled);
}
}
return vjps;
}
bool AddMM::is_equivalent(const Primitive& other) const {
const AddMM& a_other = static_cast<const AddMM&>(other);
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
}
bool Arange::is_equivalent(const Primitive& other) const {
const Arange& a_other = static_cast<const Arange&>(other);
return (
@@ -180,8 +131,7 @@ bool Arange::is_equivalent(const Primitive& other) const {
std::vector<array> ArcCos::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -208,8 +158,7 @@ std::pair<std::vector<array>, std::vector<int>> ArcCos::vmap(
std::vector<array> ArcCosh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -235,8 +184,7 @@ std::pair<std::vector<array>, std::vector<int>> ArcCosh::vmap(
std::vector<array> ArcSin::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -262,8 +210,7 @@ std::pair<std::vector<array>, std::vector<int>> ArcSin::vmap(
std::vector<array> ArcSinh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -289,8 +236,7 @@ std::pair<std::vector<array>, std::vector<int>> ArcSinh::vmap(
std::vector<array> ArcTan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -316,8 +262,7 @@ std::pair<std::vector<array>, std::vector<int>> ArcTan::vmap(
std::vector<array> ArcTanh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -360,20 +305,6 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_;
}
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
int reduce_ax = axis_ + (axis_ >= axes[0]);
auto& in = inputs[0];
std::vector<array> out;
if (reduce_type_ == ArgReduce::ArgMin) {
out.push_back(argmin(in, reduce_ax, true, stream()));
} else {
out.push_back(argmax(in, reduce_ax, true, stream()));
}
return {out, axes};
}
std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -391,8 +322,7 @@ bool ArgSort::is_equivalent(const Primitive& other) const {
std::vector<array> AsType::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
if (cotangents[0].dtype() != dtype_) {
throw std::invalid_argument(
"[astype] Type of cotangentsgent does not much primal output type.");
@@ -421,8 +351,7 @@ bool AsType::is_equivalent(const Primitive& other) const {
std::vector<array> AsStrided::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(argnums.size() == 1);
// Extract the sizes and cast them to ints
@@ -466,8 +395,7 @@ bool AsStrided::is_equivalent(const Primitive& other) const {
std::vector<array> Broadcast::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(argnums.size() == 1);
// Reduce cotangents to the shape of the primal
@@ -517,8 +445,7 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
std::vector<array> Ceil::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -542,8 +469,7 @@ std::pair<std::vector<array>, std::vector<int>> Ceil::vmap(
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape();
@@ -618,8 +544,7 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 2);
std::vector<array> grads;
@@ -736,8 +661,7 @@ bool Convolution::is_equivalent(const Primitive& other) const {
std::vector<array> Copy::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return cotangents;
@@ -763,8 +687,7 @@ std::pair<std::vector<array>, std::vector<int>> Copy::vmap(
std::vector<array> Cos::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return {jvp(primals, cotangents, argnums)};
}
@@ -789,8 +712,7 @@ std::pair<std::vector<array>, std::vector<int>> Cos::vmap(
std::vector<array> Cosh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -811,48 +733,10 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
return {{cosh(inputs[0], stream())}, axes};
}
std::vector<array> CustomVJP::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
for (const auto& cot : cotangents) {
all_vjps.emplace_back(cot);
}
std::vector<array> vjps;
vjps.reserve(argnums.size());
for (auto arg : argnums) {
vjps.push_back(all_vjps[arg]);
}
return vjps;
}
std::vector<array> Depends::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg < cotangents.size()) {
vjps.push_back(cotangents[arg]);
} else {
vjps.push_back(zeros_like(primals[arg]));
}
}
return vjps;
}
std::vector<array> Divide::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
@@ -872,8 +756,7 @@ std::vector<array> Divide::vjp(
std::vector<array> DivMod::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -929,8 +812,7 @@ std::pair<std::vector<array>, std::vector<int>> Divide::vmap(
std::vector<array> Remainder::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
@@ -983,8 +865,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
std::vector<array> Equal::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -1003,8 +884,7 @@ std::vector<array> Equal::jvp(
std::vector<array> Erf::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1033,13 +913,8 @@ std::pair<std::vector<array>, std::vector<int>> Erf::vmap(
std::vector<array> ErfInv::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto dtype = primals[0].dtype();
auto scale =
multiply(array(1.0 / M_2_SQRTPI, dtype), cotangents[0], stream());
return {
multiply(scale, exp(square(outputs[0], stream()), stream()), stream())};
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
std::vector<array> ErfInv::jvp(
@@ -1067,9 +942,8 @@ std::pair<std::vector<array>, std::vector<int>> ErfInv::vmap(
std::vector<array> Exp::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
return {multiply(cotangents[0], outputs[0], stream())};
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
std::vector<array> Exp::jvp(
@@ -1123,8 +997,7 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
std::vector<array> FFT::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
auto& in = primals[0];
@@ -1177,8 +1050,7 @@ std::vector<array> FFT::jvp(
std::vector<array> Floor::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1202,8 +1074,7 @@ std::pair<std::vector<array>, std::vector<int>> Floor::vmap(
std::vector<array> Full::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
return {multiply(cotangents[0], primals[0], stream())};
@@ -1284,8 +1155,7 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
std::vector<array> Gather::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
if (argnums.size() > 1 || argnums[0] != 0) {
throw std::invalid_argument(
"[gather] Cannot calculate VJP with respect to indices.");
@@ -1322,8 +1192,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
std::vector<array> Greater::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -1349,8 +1218,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
std::vector<array> GreaterEqual::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -1376,8 +1244,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
std::vector<array> Less::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -1403,8 +1270,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
std::vector<array> LessEqual::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -1423,8 +1289,7 @@ std::vector<array> LessEqual::jvp(
std::vector<array> Log::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1460,8 +1325,7 @@ std::pair<std::vector<array>, std::vector<int>> Log::vmap(
std::vector<array> Log1p::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1487,8 +1351,7 @@ std::pair<std::vector<array>, std::vector<int>> Log1p::vmap(
std::vector<array> LogicalNot::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1512,8 +1375,7 @@ std::pair<std::vector<array>, std::vector<int>> LogicalNot::vmap(
std::vector<array> LogicalAnd::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 2);
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
if (argnums.size() > 1) {
@@ -1544,8 +1406,7 @@ std::pair<std::vector<array>, std::vector<int>> LogicalAnd::vmap(
std::vector<array> LogicalOr::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 2);
std::vector<array> vjps = {zeros_like(cotangents[0], stream())};
if (argnums.size() > 1) {
@@ -1577,8 +1438,7 @@ std::pair<std::vector<array>, std::vector<int>> LogicalOr::vmap(
std::vector<array> LogAddExp::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
auto a = primals[0];
auto b = primals[1];
auto s = sigmoid(subtract(a, b, stream()), stream());
@@ -1623,8 +1483,7 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
std::vector<array> Matmul::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
auto& cotan = cotangents[0];
std::vector<int> reorder(cotan.ndim());
@@ -1647,8 +1506,7 @@ std::vector<array> Matmul::vjp(
std::vector<array> Maximum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
auto& a = primals[0];
auto& b = primals[1];
std::vector<array> vjps;
@@ -1689,8 +1547,7 @@ std::pair<std::vector<array>, std::vector<int>> Maximum::vmap(
std::vector<array> Minimum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
auto& a = primals[0];
auto& b = primals[1];
std::vector<array> vjps;
@@ -1744,8 +1601,7 @@ std::vector<array> Multiply::jvp(
std::vector<array> Multiply::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(multiply(primals[1 - arg], cotangents[0], stream()));
@@ -1763,8 +1619,7 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
std::vector<array> Negative::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1795,8 +1650,7 @@ std::pair<std::vector<array>, std::vector<int>> NotEqual::vmap(
std::vector<array> NotEqual::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
vjps.push_back(zeros_like(primals[arg], stream()));
@@ -1815,8 +1669,7 @@ std::vector<array> NotEqual::jvp(
std::vector<array> Pad::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(argnums.size() == 1 && argnums[0] == 0);
auto& cotan = cotangents[0];
@@ -1864,8 +1717,7 @@ bool Pad::is_equivalent(const Primitive& other) const {
std::vector<array> Partition::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -1897,8 +1749,7 @@ bool Partition::is_equivalent(const Primitive& other) const {
std::vector<array> Power::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
@@ -1910,7 +1761,10 @@ std::vector<array> Power::vjp(
primals[1],
stream()));
} else {
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
vjps.push_back(multiply(
log(primals[0], stream()),
power(primals[0], primals[1], stream()),
stream()));
}
vjps.back() = multiply(cotangents[0], vjps.back(), stream());
}
@@ -1921,13 +1775,12 @@ std::vector<array> Power::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto output = power(primals[0], primals[1], stream());
auto grads = vjp(primals, tangents, argnums, {output});
auto jvp = vjp(primals, {tangents[0]}, {argnums[0]});
if (argnums.size() > 1) {
return {add(grads[0], grads[1], stream())};
} else {
return grads;
jvp[0] =
add(jvp[0], vjp(primals, {tangents[1]}, {argnums[1]})[0], stream());
}
return jvp;
}
std::pair<std::vector<array>, std::vector<int>> Power::vmap(
@@ -1946,8 +1799,7 @@ std::pair<std::vector<array>, std::vector<int>> QuantizedMatmul::vmap(
std::vector<array> QuantizedMatmul::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
// We rely on the fact that w is always 2D so transpose is simple
@@ -2050,8 +1902,7 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
std::vector<array> Reshape::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
assert(argnums[0] == 0);
@@ -2076,8 +1927,7 @@ bool Reshape::is_equivalent(const Primitive& other) const {
std::vector<array> Reduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
const std::vector<int>& argnums) {
auto in = primals[0];
std::vector<int> shape = in.shape();
@@ -2147,10 +1997,15 @@ std::vector<array> Reduce::vjp(
}
} else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
auto out = outputs[0];
if (out.ndim() != in.ndim()) {
out = expand_dims(out, axes_, stream());
array (*op)(const array&, const std::vector<int>&, bool, StreamOrDevice);
if (reduce_type_ == Reduce::Min) {
op = min;
} else {
op = max;
}
auto out = op(in, axes_, true, stream());
auto mask = equal(in, out, stream());
auto normalizer = sum(mask, axes_, true, stream());
auto cotan_reshape = reshape(cotan, shape, stream());
@@ -2166,36 +2021,7 @@ std::vector<array> Reduce::vjp(
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0];
auto reduce_axes = axes_;
for (auto& rax : reduce_axes) {
if (rax >= ax) {
rax++;
}
}
auto& in = inputs[0];
std::vector<array> out;
switch (reduce_type_) {
case Reduce::And:
out.push_back(all(in, reduce_axes, true, stream()));
break;
case Reduce::Or:
out.push_back(any(in, reduce_axes, true, stream()));
break;
case Reduce::Sum:
out.push_back(sum(in, reduce_axes, true, stream()));
break;
case Reduce::Prod:
out.push_back(prod(in, reduce_axes, true, stream()));
break;
case Reduce::Min:
out.push_back(min(in, reduce_axes, true, stream()));
break;
case Reduce::Max:
out.push_back(max(in, reduce_axes, true, stream()));
break;
}
return {out, axes};
throw std::runtime_error("Reduce::vmap not yet implemented.");
}
bool Reduce::is_equivalent(const Primitive& other) const {
@@ -2206,8 +2032,7 @@ bool Reduce::is_equivalent(const Primitive& other) const {
std::vector<array> Round::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2251,8 +2076,7 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
std::vector<array> Scan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums[0] == 0);
@@ -2260,7 +2084,7 @@ std::vector<array> Scan::vjp(
return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
} else if (reduce_type_ == Scan::Prod) {
// TODO: Make it numerically stable when we introduce where()
auto prod = outputs[0];
auto prod = cumprod(primals[0], axis_, reverse_, inclusive_, stream());
auto partial_grads = multiply(prod, cotangents[0], stream());
auto accum_grads =
cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
@@ -2301,20 +2125,16 @@ bool Scatter::is_equivalent(const Primitive& other) const {
std::vector<array> Scatter::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum:
case Scatter::Max:
case Scatter::Min:
break;
default:
throw std::runtime_error(
"[scatter] VJP not implemented for scatter_prod");
"[scatter] VJP implemented only for scatter and scatter_add");
}
const array& result = outputs[0];
const array& values = primals[0];
const array& updates = primals.back();
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
@@ -2337,12 +2157,6 @@ std::vector<array> Scatter::vjp(
// The input array values are kept so they all get gradients
vjps.push_back(cotangents[0]);
break;
case Scatter::Max:
case Scatter::Min: {
auto mask = where(result == values, array({1}), array({0}));
vjps.push_back(multiply(cotangents[0], mask));
break;
}
default:
// Should never reach here
throw std::invalid_argument("");
@@ -2360,20 +2174,6 @@ std::vector<array> Scatter::vjp(
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
break;
}
case Scatter::Max:
case Scatter::Min: {
auto slice_sizes = cotangents[0].shape();
for (auto ax : axes_) {
slice_sizes[ax] = 1;
}
auto gathered_cotan =
gather(cotangents[0], indices, axes_, slice_sizes, stream());
auto gathered_result =
gather(result, indices, axes_, slice_sizes, stream());
vjps.push_back(
multiply(gathered_cotan, gathered_result == updates, stream()));
break;
}
default: {
// Should never reach here
throw std::invalid_argument("");
@@ -2397,12 +2197,8 @@ std::vector<array> Scatter::jvp(
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto& s = outputs[0];
auto sprime =
multiply(s, subtract(array(1.0f, s.dtype()), s, stream()), stream());
return {multiply(cotangents[0], sprime, stream())};
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
std::vector<array> Sigmoid::jvp(
@@ -2428,8 +2224,7 @@ std::pair<std::vector<array>, std::vector<int>> Sigmoid::vmap(
std::vector<array> Sign::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2453,8 +2248,7 @@ std::pair<std::vector<array>, std::vector<int>> Sign::vmap(
std::vector<array> Sin::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2478,8 +2272,7 @@ std::pair<std::vector<array>, std::vector<int>> Sin::vmap(
std::vector<array> Sinh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2517,8 +2310,7 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
std::vector<array> Slice::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
// Check inputs
assert(primals.size() == 1);
@@ -2617,15 +2409,8 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
std::vector<array> Softmax::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 1);
assert(cotangents.size() == 1);
auto& s = outputs[0];
auto sv = multiply(s, cotangents[0], stream());
return {subtract(
sv,
multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
std::vector<array> Softmax::jvp(
@@ -2653,8 +2438,7 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
std::vector<array> Sort::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2674,38 +2458,10 @@ bool Sort::is_equivalent(const Primitive& other) const {
return axis_ == r_other.axis_;
}
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {
{split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
}
std::vector<array> Split::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
return {concatenate(cotangents, axis_, stream())};
}
std::vector<array> Split::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
return split(tangents[0], indices_, axis_, stream());
}
bool Split::is_equivalent(const Primitive& other) const {
const Split& s_other = static_cast<const Split&>(other);
return axis_ == s_other.axis_ && indices_ == s_other.indices_;
}
std::vector<array> Square::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2732,34 +2488,29 @@ std::pair<std::vector<array>, std::vector<int>> Square::vmap(
std::vector<array> Sqrt::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
assert(primals.size() == 1);
assert(cotangents.size() == 1);
auto dtype = primals[0].dtype();
if (recip_) {
auto one_over_x_root_x = divide(outputs[0], primals[0], stream());
return {multiply(
multiply(array(-0.5, dtype), cotangents[0], stream()),
one_over_x_root_x,
stream())};
} else {
return {divide(
multiply(array(0.5, dtype), cotangents[0], stream()),
outputs[0],
stream())};
}
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
std::vector<array> Sqrt::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(tangents.size() == 1);
auto dtype = primals[0].dtype();
if (recip_) {
return vjp(primals, tangents, argnums, {rsqrt(primals[0], stream())});
} else {
return vjp(primals, tangents, argnums, {sqrt(primals[0], stream())});
auto one_over_x_root_x =
divide(rsqrt(primals[0], stream()), primals[0], stream());
return {multiply(
multiply(array(-0.5, dtype), tangents[0], stream()),
one_over_x_root_x,
stream())};
}
return {divide(
multiply(array(0.5, dtype), tangents[0], stream()),
sqrt(primals[0], stream()),
stream())};
}
std::pair<std::vector<array>, std::vector<int>> Sqrt::vmap(
@@ -2787,8 +2538,7 @@ std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
std::vector<array> Subtract::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
std::vector<array> vjps;
for (auto arg : argnums) {
auto vjp = cotangents[0];
@@ -2825,8 +2575,7 @@ std::pair<std::vector<array>, std::vector<int>> Subtract::vmap(
std::vector<array> Tan::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2851,8 +2600,7 @@ std::pair<std::vector<array>, std::vector<int>> Tan::vmap(
std::vector<array> Tanh::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
return jvp(primals, cotangents, argnums);
}
@@ -2877,8 +2625,7 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
std::vector<array> Transpose::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(argnums.size() == 1);
std::vector<int> iaxes(axes_.size());

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
@@ -21,8 +21,7 @@
std::vector<array> vjp( \
const std::vector<array>& primals, \
const std::vector<array>& cotangents, \
const std::vector<int>& argnums, \
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
#define DEFINE_PRINT(PRIMITIVE) \
void print(std::ostream& os) override { \
@@ -79,8 +78,7 @@ class Primitive {
virtual std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs);
const std::vector<int>& argnums);
/**
* The primitive must know how to vectorize itself across
@@ -171,29 +169,6 @@ class Add : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class AddMM : public UnaryPrimitive {
public:
explicit AddMM(Stream stream, float alpha, float beta)
: UnaryPrimitive(stream), alpha_(alpha), beta_(beta){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(AddMM)
bool is_equivalent(const Primitive& other) const override;
private:
const float alpha_;
const float beta_;
};
class Arange : public UnaryPrimitive {
public:
explicit Arange(Stream stream, double start, double stop, double step)
@@ -341,7 +316,6 @@ class ArgReduce : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(ArgReduce)
bool is_equivalent(const Primitive& other) const override;
@@ -490,8 +464,7 @@ class Convolution : public UnaryPrimitive {
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
DEFINE_PRINT(Convolution)
bool is_equivalent(const Primitive& other) const override;
@@ -553,60 +526,6 @@ class Cosh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class CustomVJP : public Primitive {
public:
explicit CustomVJP(
Stream stream,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun)
: Primitive(stream), vjp_fun_(std::move(fun)) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(CustomVJP);
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>
vjp_fun_;
};
class Depends : public Primitive {
public:
explicit Depends(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(Depends);
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
class Divide : public UnaryPrimitive {
public:
explicit Divide(Stream stream) : UnaryPrimitive(stream){};
@@ -1000,8 +919,7 @@ class Matmul : public UnaryPrimitive {
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
@@ -1235,8 +1153,7 @@ class Reduce : public UnaryPrimitive {
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
const std::vector<int>& argnums) override;
void print(std::ostream& os) override {
switch (reduce_type_) {
@@ -1504,28 +1421,6 @@ class Sort : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class Split : public Primitive {
public:
explicit Split(Stream stream, const std::vector<int>& indices, int axis)
: Primitive(stream), indices_(indices), axis_(axis){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Split)
bool is_equivalent(const Primitive& other) const override;
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::vector<int> indices_;
int axis_;
};
class Square : public UnaryPrimitive {
public:
explicit Square(Stream stream) : UnaryPrimitive(stream){};
@@ -1657,20 +1552,4 @@ class Transpose : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
/* QR Factorization primitive. */
class QRF : public Primitive {
public:
explicit QRF(Stream stream) : Primitive(stream){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(QRF)
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
} // namespace mlx::core

View File

@@ -2,7 +2,6 @@
#pragma once
#include <chrono>
#include <optional>
#include "mlx/array.h"
@@ -19,18 +18,12 @@ class KeySequence {
// static default
static KeySequence& default_() {
static KeySequence ks(get_current_time_seed());
static KeySequence ks(0);
return ks;
}
private:
array key_;
static uint64_t get_current_time_seed() {
auto now = std::chrono::system_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch())
.count();
}
};
/** Get a PRNG key from a seed. */

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <future>
#include <map>
#include <numeric>
#include <set>
#include <sstream>
@@ -34,11 +35,187 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) {
std::function<void(const array&, bool)> recurse;
void simplify(const std::vector<array>& outputs) {
// Some notes about how this function works
//
// Step 1: Traverse the graph and build a tape. During the graph
// traversal we:
// - Build a map of inputs to their parents.
// - Record scalar inputs in a map in order to fuse them.
// Step 2: Process the tape. A node in the tape has inputs and outputs.
// - Scalar inputs are replaced with their canonical scalar
// - We check each inputs output nodes. Every output node that matches
// the current node gets fused into the current node.
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
parents_map;
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
auto is_scalar = [](const array& a) {
return a.is_evaled() && a.ndim() == 0;
};
auto get_scalar_rep = [](const array& a) {
uint64_t v = 0;
int dtype;
switch (a.dtype().size) {
case 1:
v = *a.data<uint8_t>();
break;
case 4:
v = *a.data<uint32_t>();
break;
case 8:
v = *a.data<uint64_t>();
break;
}
return std::make_pair(v, a.dtype().val);
};
// DFS the graph to build the tape, and log parents and scalars
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (int i = 0; i < a.inputs().size(); i++) {
auto& in = a.inputs()[i];
parents_map[in.id()].push_back({a, i});
for (auto& s : a.siblings()) {
parents_map[in.id()].push_back({s, i});
}
recurse(in);
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
tape.push(a);
if (is_scalar(a)) {
scalars.insert({get_scalar_rep(a), a});
}
};
for (auto& a : outputs) {
recurse(a);
}
// Helper that fuses two arrays in the graph by setting the parents of the
// source to point to the destination
auto fuse = [&](array& dst, array& src) {
// Canonicalize the order of the primitives outputs
auto sources = src.outputs();
auto dests = dst.outputs();
// For each src parent, point it to the corresponding dest
for (int i = 0; i < sources.size(); ++i) {
auto src_parents = parents_map.find(sources[i].id());
if (src_parents == parents_map.end()) {
continue;
}
auto& pairs = parents_map[dests[i].id()];
for (auto& parent : src_parents->second) {
parent.first.inputs()[parent.second] = dests[i];
pairs.push_back(parent);
}
// Remove the source from the map to avoid fusing with it again
parents_map.erase(src_parents);
}
};
// Depth-1 array equivalence check.
auto array_equivalent = [](const array& a, const array& b) {
if (!a.has_primitive() || !b.has_primitive()) {
return false;
}
if (a.primitive_id() == b.primitive_id()) {
return false;
}
const auto& pa = a.primitive();
const auto& pb = b.primitive();
if (typeid(pa) != typeid(pb)) {
return false;
}
if (a.inputs().size() != b.inputs().size()) {
return false;
}
for (int i = 0; i < a.inputs().size(); i++) {
if (a.inputs()[i].id() != b.inputs()[i].id()) {
return false;
}
}
return pa.is_equivalent(pb);
};
// Walk the graph
while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
// Check if we can fuse scalars
if (is_scalar(arr)) {
auto scalar = scalars.find(get_scalar_rep(arr));
if (scalar->second.id() != arr.id()) {
fuse(scalar->second, arr);
arr = scalar->second;
}
}
// Helper to check if we can fuse the parents of the
// given array
auto maybe_fuse_parents = [&](auto& a) {
auto parents = parents_map.find(a.id());
if (parents != parents_map.end()) {
auto N = parents->second.size();
std::vector<bool> mask(N, false);
for (int i = 0; i < N; i++) {
if (mask[i]) {
continue;
}
for (int j = i + 1; j < N; j++) {
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
fuse(dst, src);
mask[j] = true;
}
}
}
}
};
maybe_fuse_parents(arr);
for (auto& s : arr.siblings()) {
maybe_fuse_parents(s);
}
}
}
void eval(const std::vector<array>& outputs) {
std::function<int(const array&)> recurse;
std::unordered_map<std::uintptr_t, int> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
struct ArrayWithPriority {
int depth;
int order;
array x;
};
auto cmp = [](const ArrayWithPriority& a, const ArrayWithPriority& b) {
return (a.depth > b.depth) || (a.depth == b.depth && a.order > b.order);
};
std::priority_queue<
ArrayWithPriority,
std::vector<ArrayWithPriority>,
decltype(cmp)>
tape(cmp);
int order;
// Make an effort to choose a good output stream
Stream stream = default_stream(default_device());
@@ -52,77 +229,44 @@ void eval(const std::vector<array>& outputs) {
auto synchronizer =
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
recurse = [&](const array& a, bool largest_branch_first) {
recurse = [&](const array& a) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
if (auto it = cache.find(id); it != cache.end()) {
return it->second;
}
// If the input is being computed on a different stream, we need to manage
// the dependency.
auto check_dependency = [&](const array& in) {
int input_depth = 0;
for (auto in : a.inputs()) {
input_depth = std::max(input_depth, recurse(in));
// If one of the inputs is being computed on a different
// stream, we need to manage the dependency.
if (!in.is_evaled()) {
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.primitive_id(), std::shared_future<void>{}});
}
}
};
// Recurse to the largest or smallest branch first.
size_t num_inputs = a.inputs().size();
if (num_inputs == 1) {
auto& in = a.inputs()[0];
recurse(in, true);
check_dependency(in);
} else if (num_inputs == 2) {
auto depth_1 = a.inputs()[0].graph_depth();
auto depth_2 = a.inputs()[1].graph_depth();
auto& in1 = a.inputs()[static_cast<int>(
!((depth_1 > depth_2) == largest_branch_first))];
auto& in2 = a.inputs()[static_cast<int>(
((depth_1 > depth_2) == largest_branch_first))];
recurse(in1, true);
check_dependency(in1);
recurse(in2, true);
check_dependency(in2);
} else if (num_inputs > 2) {
std::vector<int> recursion_order(a.inputs().size());
std::iota(recursion_order.begin(), recursion_order.end(), 0);
std::sort(
recursion_order.begin(),
recursion_order.end(),
[&a, largest_branch_first](int i, int j) {
auto depth_i = a.inputs()[i].graph_depth();
auto depth_j = a.inputs()[j].graph_depth();
return largest_branch_first ? depth_i > depth_j : depth_j < depth_i;
});
for (int idx : recursion_order) {
auto& in = a.inputs()[idx];
recurse(in, true);
check_dependency(in);
}
}
cache.insert(id);
cache.insert({id, input_depth + 1});
for (auto& s : a.siblings()) {
cache.insert(s.id());
cache.insert({s.id(), input_depth + 1});
}
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
}
tape.push(a);
tape.push({input_depth + 1, order++, a});
}
return input_depth + 1;
};
recurse(synchronizer, false);
recurse(synchronizer);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});
std::vector<std::shared_ptr<std::promise<void>>> ps;
while (!tape.empty()) {
auto arr = std::move(tape.front());
auto val = std::move(tape.top());
auto arr = std::move(val.x);
tape.pop();
if (arr.is_evaled()) {
if (!arr.is_tracer() && arr.has_primitive()) {
@@ -134,7 +278,6 @@ void eval(const std::vector<array>& outputs) {
auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps;
for (auto& in : arr.inputs()) {
// TODO that's a bug
if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
arr_deps.push_back(it->second);
}
@@ -165,6 +308,9 @@ void eval(const std::vector<array>& outputs) {
arr.primitive().eval_cpu(arr.inputs(), outputs);
if (!arr.is_tracer()) {
arr.detach();
for (auto s : arr.siblings()) {
s.detach();
}
}
if (p) {
p->set_value();
@@ -210,21 +356,12 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
if (cotan_index >= cotans.size()) {
std::ostringstream msg;
msg << "[vjp] Number of outputs to compute gradients for ("
<< outputs.size() << ") does not match number of cotangents ("
<< cotans.size() << ").";
throw std::invalid_argument(msg.str());
throw std::invalid_argument(
"[vjp] Number of outputs with gradient does not match number of cotangents.");
}
if (out.shape() != cotans[cotan_index].shape()) {
std::ostringstream msg;
msg << "[vjp] Output shape " << out.shape()
<< " does not match cotangent shape " << cotans[cotan_index].shape()
<< ".";
if (outputs.size() == 1 && out.size() == 1) {
msg << " If you are using grad your function must return a scalar.";
}
throw std::invalid_argument(msg.str());
throw std::invalid_argument(
"[vjp] Output shape does not match shape of cotangent.");
}
output_cotan_pairs.emplace_back(i, cotan_index++);
}
@@ -322,7 +459,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums);
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();
@@ -548,8 +685,9 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
"[vmap] The number of in axes must match the number of inputs.");
}
// Some error checking and get the vmap axis size
size_t vmap_ax_size;
// Run the function on placeholder inputs
// to get the original graph
std::vector<array> s_inputs;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
if (inputs[i].ndim() == 0) {
@@ -562,26 +700,7 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
<< inputs[i].ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
vmap_ax_size = inputs[i].shape(in_axes[i]);
}
}
// Check that all vmapped axes have the same size
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
std::ostringstream msg;
msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and "
<< vmap_ax_size << ".";
throw std::invalid_argument(msg.str());
}
}
}
// Run the function on placeholder inputs
// to get the original graph
std::vector<array> s_inputs;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] != -1) {
std::vector<int> shape = inputs[i].shape();
shape.erase(shape.begin() + in_axes[i]);
array in(shape, inputs[i].dtype(), nullptr, {});
@@ -767,58 +886,4 @@ std::function<array(const array&)> vmap(
return [vfun](const array& a) { return vfun({a})[0]; };
}
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
return [fun = std::move(fun),
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
// Compute the outputs
auto outputs = fun(args);
for (auto& out : outputs) {
out = stop_gradient(out);
}
// Prepare the inputs to the primitive
// We also add the outputs to the primitive so that it can "run" the forward
// pass.
std::vector<array> inputs = args;
inputs.insert(inputs.end(), outputs.begin(), outputs.end());
// Compute the stream. Maybe do it in a smarter way at some point in the
// future.
Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()
: default_stream(default_device());
// Make the output info
std::vector<std::vector<int>> shapes;
std::vector<Dtype> dtypes;
for (const auto& out : outputs) {
shapes.emplace_back(out.shape());
dtypes.emplace_back(out.dtype());
}
return array::make_arrays(
shapes,
dtypes,
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
inputs);
};
}
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun) {
auto vjp_fun = [fun](
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<array>& outputs) -> std::vector<array> {
auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);
return vjps;
};
return custom_vjp(fun, vjp_fun);
}
} // namespace mlx::core

View File

@@ -1,25 +1,18 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/array.h"
#include "array.h"
namespace mlx::core {
// Compile takes a function and returns a new function
std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Fuse equivalent arrays to avoid duplicate execution. */
void simplify(const std::vector<array>& outputs);
/** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation.
*/
void disable_compile();
/** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/
void enable_compile();
template <typename... Arrays>
void simplify(Arrays... outputs) {
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
}
void eval(const std::vector<array>& outputs);
@@ -191,22 +184,4 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::vector<int>& in_axes = {},
const std::vector<int>& out_axes = {});
/**
* Return the results of calling fun with args but if their vjp is computed it
* will be computed by fun_vjp.
*/
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp);
/**
* Checkpoint the gradient of a function. Namely, discard all intermediate
* state and recalculate it when we need to compute the gradient.
*/
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun);
} // namespace mlx::core

Some files were not shown because too many files have changed in this diff Show More