Merge branch 'ml-explore:main' into main

Luca Arnaboldi 2024-03-04 10:57:32 +01:00 committed by GitHub
commit c02602a4a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
94 changed files with 5858 additions and 1575 deletions


@ -237,6 +237,14 @@ workflows:
jobs:
- mac_build_and_test
- linux_build_and_test
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:


@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops.
@ -256,4 +256,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.


@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.3.0)
set(MLX_VERSION 0.5.0)
endif()
# --------------------- Processor tests -------------------------
@ -67,8 +67,6 @@ if (MLX_BUILD_METAL AND NOT METAL_LIB)
set(MLX_BUILD_METAL OFF)
elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources")
add_compile_definitions(_METAL_)
# Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION


@ -11,10 +11,12 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
MLX also has a fully featured C++ API, which closely mirrors the Python API.
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
that closely follow PyTorch to simplify building more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
- **Composable function transformations**: MLX supports composable function
transformations for automatic differentiation, automatic vectorization,
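As a quick illustration of the NumPy-style API and composable transformations this README section describes, here is a minimal Python sketch (not part of the diff; only standard `mlx.core` functions are used):

import mlx.core as mx

# NumPy-style arrays and ops, evaluated lazily
a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = mx.ones((2, 3))

def loss(x):
    return mx.mean((x + b) ** 2)

# Composable transformation: mx.grad returns a new function
grad_fn = mx.grad(loss)
print(loss(a), grad_fn(a))  # printing forces evaluation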


@ -73,6 +73,7 @@ void time_unary_ops() {
void time_binary_ops() {
int M = 1000, N = 100, K = 10;
auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@ -84,7 +85,9 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
TIME(where, condition, a, b, device);
condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@ -93,7 +96,9 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
TIMEM("scalar-vector", where, condition, a, b, device);
condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
@ -101,6 +106,7 @@ void time_binary_ops() {
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}
void time_strided_ops() {


@ -0,0 +1,129 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@ -7,12 +7,14 @@ import torch
from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[idx] = x
dst[*idx] = x
mx.eval(dst)
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
idx = []
for idx_shape in idx_shapes:
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
x = mx.random.normal(x_shape).astype(mx.float32)
dst = mx.random.normal(dst_shape).astype(mx.float32)
@ -20,13 +22,15 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
print(f"MLX: {runtime:.3f}ms")
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[idx] = x
dst[*idx] = x
if device == torch.device("mps"):
torch.mps.synchronize()
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
idx = []
for idx_shape in idx_shapes:
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
@ -45,9 +49,45 @@ if __name__ == "__main__":
else:
device = torch.device("mps")
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
dst_shapes = [
(10, 64),
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(20_000_000,),
(10000, 64),
(100, 64),
(100, 10_000, 64),
(10, 100, 100, 21),
(1_000, 1_000, 10),
]
idx_shapes = [
[(1_000_000,)],
[(1_000_000,)],
[(100_000,)],
[(1_000_000,)],
[(20_000_000,)],
[(20_000_000,)],
[(1000000,)],
[(10000000,)],
[(1_000,)],
[(10_000,)],
[(1_000,), (1_000,)],
]
x_shapes = [
(1_000_000, 64),
(1_000_000, 64),
(100_000, 64),
(1_000_000,),
(20_000_000,),
(20_000_000,),
(1000000, 64),
(10000000, 64),
(1_000, 10_000, 64),
(10_000, 100, 100, 21),
(1_000, 10),
]
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)

Binary file not shown (image modified: 7.2 KiB before, 76 KiB after).

Binary file not shown (new image: 48 KiB).


@ -49,10 +49,12 @@ html_theme_options = {
"repository_url": "https://github.com/ml-explore/mlx",
"use_repository_button": True,
"navigation_with_keys": False,
"logo": {
"image_light": "_static/mlx_logo.png",
"image_dark": "_static/mlx_logo_dark.png",
},
}
html_logo = "_static/mlx_logo.png"
# -- Options for HTMLHelp output ---------------------------------------------


@ -64,6 +64,7 @@ are the CPU and GPU.
python/transforms
python/fft
python/linalg
python/metal
python/nn
python/optimizers
python/tree_utils

docs/src/python/metal.rst (new file, 14 lines)

@ -0,0 +1,14 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
get_active_memory
get_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
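For context, a hedged sketch of calling these memory helpers from Python, assuming the bindings mirror the names listed above (the byte values are illustrative only):

import mlx.core as mx

if mx.metal.is_available():
    # Cap the buffer cache at 2 GiB and the overall memory limit at 8 GiB;
    # both setters return the previous limit.
    prev_cache = mx.metal.set_cache_limit(2 << 30)
    prev_limit = mx.metal.set_memory_limit(8 << 30, relaxed=True)

    x = mx.random.normal((4096, 4096))
    mx.eval(x)
    print(mx.metal.get_active_memory(),
          mx.metal.get_peak_memory(),
          mx.metal.get_cache_memory())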


@ -12,13 +12,24 @@ simple functions.
:toctree: _autosummary_functions
:template: nn-module-template.rst
elu
gelu
gelu_approx
gelu_fast_approx
glu
hardswish
leaky_relu
log_sigmoid
log_softmax
mish
prelu
relu
relu6
selu
softshrink
sigmoid
silu
softmax
softplus
softshrink
step
tanh
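A small sketch (not part of the diff) showing how a few of the functional activations listed above are called:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5])
print(nn.relu(x))        # elementwise max(x, 0)
print(nn.gelu(x))        # exact GELU
print(nn.softshrink(x))  # soft shrinkage with the default lambda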


@ -40,3 +40,4 @@ Layers
Softshrink
Step
Transformer
Upsample
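`Upsample` is the layer newly listed here; a hedged usage sketch, assuming the common `scale_factor`/`mode` constructor arguments:

import mlx.core as mx
import mlx.nn as nn

# NHWC input: one 2x2 single-channel image
x = mx.array([[[[1.0], [2.0]], [[3.0], [4.0]]]])
up = nn.Upsample(scale_factor=2, mode="nearest")
print(up(x).shape)  # expected: (1, 4, 4, 1)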


@ -35,6 +35,7 @@ Operations
convolve
conv1d
conv2d
conv_general
cos
cosh
dequantize
@ -56,6 +57,7 @@ Operations
greater_equal
identity
inner
isclose
isnan
isposinf
isneginf
@ -120,6 +122,8 @@ Operations
tan
tanh
tensordot
tile
topk
transpose
tri
tril
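A quick sketch (not part of the diff) exercising a few of the ops newly added to this list:

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
print(mx.tile(a, 2))                       # [1, 2, 3, 1, 2, 3]
print(mx.isclose(a, a + 1e-9))             # elementwise closeness check
print(mx.topk(mx.array([5, 1, 9, 3]), 2))  # the two largest values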


@ -8,6 +8,8 @@ Schedulers
.. autosummary::
:toctree: _autosummary
step_decay
exponential_decay
cosine_decay
exponential_decay
join_schedules
linear_schedule
step_decay
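An illustrative sketch (not part of the diff) combining the schedulers listed above, assuming the optimizers accept a schedule callable as `learning_rate`:

import mlx.optimizers as optim

# Linear warmup for 100 steps, then cosine decay over 1000 steps
warmup = optim.linear_schedule(0.0, 1e-3, 100)
cosine = optim.cosine_decay(1e-3, 1000)
lr = optim.join_schedules([warmup, cosine], [100])

opt = optim.SGD(learning_rate=lr)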


@ -64,6 +64,7 @@ DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)


@ -1,6 +1,9 @@
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(COMPILER ${CMAKE_C_COMPILER})
set(CLANG TRUE)
else()
set(COMPILER ${CMAKE_CXX_COMPILER})
endif()
add_custom_command(
@ -8,16 +11,16 @@ add_custom_command(
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CXX_COMPILER}
${CMAKE_SOURCE_DIR}
${COMPILER}
${PROJECT_SOURCE_DIR}
${CLANG}
DEPENDS make_compiled_preamble.sh
compiled_preamble.h
${CMAKE_SOURCE_DIR}/mlx/types/half_types.h
${CMAKE_SOURCE_DIR}/mlx/types/fp16.h
${CMAKE_SOURCE_DIR}/mlx/types/bf16.h
${CMAKE_SOURCE_DIR}/mlx/types/complex.h
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
ops.h
)
@ -43,6 +46,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp


@ -9,7 +9,7 @@ namespace mlx::core {
namespace {
enum BinaryOpType {
enum class BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
@ -20,17 +20,17 @@ enum BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
bopt = ScalarScalar;
bopt = BinaryOpType::ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
bopt = ScalarVector;
bopt = BinaryOpType::ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
bopt = VectorScalar;
bopt = BinaryOpType::VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
bopt = VectorVector;
bopt = BinaryOpType::VectorVector;
} else {
bopt = General;
bopt = BinaryOpType::General;
}
return bopt;
}
@ -42,11 +42,11 @@ void set_binary_op_output_data(
BinaryOpType bopt,
bool donate_with_move = false) {
switch (bopt) {
case ScalarScalar:
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
@ -61,7 +61,7 @@ void set_binary_op_output_data(
b.flags());
}
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
@ -76,7 +76,7 @@ void set_binary_op_output_data(
a.flags());
}
break;
case VectorVector:
case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
@ -97,7 +97,7 @@ void set_binary_op_output_data(
a.flags());
}
break;
case General:
case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
@ -424,25 +424,25 @@ void binary_op(
set_binary_op_output_data(a, b, out, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
if (bopt == BinaryOpType::ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
if (bopt == BinaryOpType::VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
if (bopt == BinaryOpType::VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
@ -475,17 +475,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
bopt = BinaryOpType::ScalarVector;
dim = d;
}
@ -495,20 +495,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:


@ -260,14 +260,14 @@ void binary_op(
set_binary_op_output_data(a, b, out_b, bopt);
// The full computation is scalar scalar so call the base op once
if (bopt == ScalarScalar) {
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}
// The full computation is scalar vector so delegate to the op
if (bopt == ScalarVector) {
if (bopt == BinaryOpType::ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
@ -278,7 +278,7 @@ void binary_op(
}
// The full computation is vector scalar so delegate to the op
if (bopt == VectorScalar) {
if (bopt == BinaryOpType::VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
@ -289,7 +289,7 @@ void binary_op(
}
// The full computation is vector vector so delegate to the op
if (bopt == VectorVector) {
if (bopt == BinaryOpType::VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
@ -327,17 +327,17 @@ void binary_op(
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
bopt = VectorVector;
bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
bopt = VectorScalar;
bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
bopt = ScalarVector;
bopt = BinaryOpType::ScalarVector;
dim = d;
}
@ -347,20 +347,20 @@ void binary_op(
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
bopt = General;
bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}
switch (bopt) {
case VectorVector:
case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:


@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
@ -27,14 +28,16 @@ void slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* start_wt_ptr = wt.data<T>();
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
@ -61,12 +64,15 @@ void slow_conv_1D(
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
if (ih >= 0 && ih < iH) {
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih * in_stride_H + c * in_stride_C]) *
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
@ -90,14 +96,16 @@ void slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
@ -120,6 +128,8 @@ void slow_conv_2D(
const size_t out_stride_W = out.strides()[2];
const size_t out_stride_O = out.strides()[3];
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
auto pt_conv_no_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
@ -131,8 +141,10 @@ void slow_conv_2D(
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
@ -153,25 +165,74 @@ void slow_conv_2D(
} // o
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int ih = ih_base + wh * wt_dilation[0];
int iw = iw_base + ww * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
@ -191,13 +252,17 @@ void slow_conv_2D(
};
int oH_border_0 = 0;
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
int oH_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
int oW_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
@ -246,15 +311,18 @@ void dispatch_slow_conv_1D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
return slow_conv_1D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_1D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_1D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@ -267,15 +335,18 @@ void dispatch_slow_conv_2D(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
return slow_conv_2D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_2D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_2D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation);
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
@ -493,13 +564,16 @@ void conv_1D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
if (wt_dilation[0] == 1) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation);
}
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
void conv_2D_cpu(
@ -508,8 +582,11 @@ void conv_2D_cpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
@ -523,12 +600,26 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
// 2D convolution
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 1D convolution
else if (in.ndim() == (1 + 2)) {
return conv_1D_cpu(
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// Throw error
else {


@ -87,6 +87,7 @@ DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)


@ -7,6 +7,10 @@
namespace mlx::core::detail {
namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace
typedef union {
int i;
float f;
@ -588,4 +592,11 @@ struct LogicalOr {
};
};
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
} // namespace mlx::core::detail
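The new `Select` struct is the elementwise ternary op (condition ? x : y) added on the CPU path; at the Python level this kind of selection surfaces as the `where` op that the benchmarks above now time, sketched below (not part of the diff):

import mlx.core as mx

cond = mx.array([True, False, True])
x = mx.array([1.0, 2.0, 3.0])
y = mx.array([10.0, 20.0, 30.0])
print(mx.where(cond, x, y))  # [1.0, 20.0, 3.0]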


@ -0,0 +1,72 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename Op>
void select_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
}
} // namespace
void Select::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
}
} // namespace mlx::core


@ -0,0 +1,226 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
namespace {
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
ScalarScalarScalar,
General,
};
TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
TernaryOpType topt;
if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
topt = TernaryOpType::ScalarScalarScalar;
} else {
topt = TernaryOpType::General;
}
return topt;
}
void set_ternary_op_output_data(
const array& a,
const array& b,
const array& c,
array& out,
TernaryOpType topt,
bool donate_with_move = false) {
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::General:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
break;
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims1(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
for (size_t i = 0; i < out.size(); ++i) {
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[0];
b_idx += b.strides()[0];
c_idx += c.strides()[0];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[1];
b_idx += b.strides()[1];
c_idx += c.strides()[1];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[2];
b_idx += b.strides()[2];
c_idx += c.strides()[2];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
size_t a_idx = 0;
size_t b_idx = 0;
size_t c_idx = 0;
size_t out_idx = 0;
for (size_t i = 0; i < a.shape()[0]; ++i) {
for (size_t j = 0; j < a.shape()[1]; ++j) {
for (size_t k = 0; k < a.shape()[2]; ++k) {
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
a_idx += a.strides()[3];
b_idx += b.strides()[3];
c_idx += c.strides()[3];
}
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
}
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
}
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.ndim()) {
case 1:
ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 2:
ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 3:
ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
case 4:
ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
return;
}
const T1* a_ptr = a.data<T1>();
const T2* b_ptr = b.data<T2>();
const T3* c_ptr = c.data<T3>();
U* dst = out.data<U>();
for (size_t i = 0; i < out.size(); i++) {
int a_idx = elem_to_loc(i, a.shape(), a.strides());
int b_idx = elem_to_loc(i, b.shape(), b.strides());
int c_idx = elem_to_loc(i, c.shape(), c.strides());
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
}
}
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt);
// The full computation is scalar-scalar-scalar so we call the base op once.
if (topt == TernaryOpType::ScalarScalarScalar) {
*(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
return;
}
ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}
} // namespace
} // namespace mlx::core


@ -4,7 +4,7 @@ add_custom_command(
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${CMAKE_SOURCE_DIR}
${PROJECT_SOURCE_DIR}
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h


@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
@ -23,16 +23,6 @@ void* Buffer::raw_ptr() {
namespace metal {
static bool cache_enabled_ = true;
bool cache_enabled() {
return cache_enabled_;
}
void set_cache_enabled(bool enabled) {
cache_enabled_ = enabled;
}
namespace {
BufferCache::BufferCache(MTL::Device* device)
@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::swap(limit, max_pool_size_);
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Try the cache
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
size_t pool_size = get_cache_memory();
if (!buf) {
size_t mem_required = get_active_memory() + pool_size + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// If we have a lot of memory pressure, check if we can reclaim some memory
// from the cache
if (device_->currentAllocatedSize() + size >= gc_limit_) {
size_t min_bytes_to_free =
size + device_->currentAllocatedSize() - gc_limit_;
buffer_cache_.release_cached_buffers(min_bytes_to_free);
if (mem_required >= gc_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
}
// Allocate new buffer if needed
@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buf = device_->newBuffer(size, res_opt);
}
peak_allocated_size_ =
std::max(peak_allocated_size_, device_->currentAllocatedSize());
// Maintain the cache below the requested limit
if (pool_size >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
}
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (cache_enabled()) {
active_memory_ -= buf->length();
if (max_pool_size_ > 0) {
buffer_cache_.recycle_to_cache(buf);
} else {
buf->release();
@ -218,6 +229,22 @@ MetalAllocator& allocator() {
return allocator_;
}
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t get_active_memory() {
return allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
}
} // namespace metal
} // namespace mlx::core


@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@ -24,6 +24,9 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free);
size_t pool_size() {
return pool_size_;
}
private:
struct BufferHolder {
@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};
size_t get_peak_memory() {
return peak_memory_;
};
size_t get_cache_memory() {
return buffer_cache_.pool_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
private:
MTL::Device* device_;
@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator {
BufferCache buffer_cache_;
// Allocation stats
size_t peak_allocated_size_;
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
bool relaxed_{true};
};
MetalAllocator& allocator();


@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@ -7,80 +7,72 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
using namespace mlx::steel;
namespace mlx::core {
namespace {
void explicit_gemm_conv_1D_gpu(
template <int N>
void explicit_gemm_conv_ND_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<1>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
const MLXConvParams<N>& conv_params) {
// Prepare unfolding array
std::vector<int> unfolded_shape = {
static_cast<int>(out.size() / conv_params.O),
static_cast<int>(wt.size() / conv_params.O)};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1);
// Make strided view
std::vector<int> strided_shape = {
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
std::vector<array> copies;
return steel_matmul(
s,
d,
/*a = */ in_strided,
/*a = */ in_unfolded,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*M = */ unfolded_shape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*K = */ unfolded_shape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_cols = */ unfolded_shape[1],
/*b_cols = */ unfolded_shape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
@ -94,7 +86,9 @@ void conv_1D_gpu(
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
@ -105,24 +99,19 @@ void conv_1D_gpu(
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int dil[NDIM] = */ {wt_dilation[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
/* const int idil[NDIM] = */ {in_dilation[0]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
};
/* const int groups = */ 1,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
if (wt_dilation[0] == 1) {
explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
}
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
void slow_conv_2D_gpu(
@ -168,113 +157,262 @@ void implicit_gemm_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
int bm = 32, bn = 32, bk = 16;
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
// Determine block and warp tiles
int wm = 2, wn = 2;
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
int bk = 16;
if (implicit_N <= 16) {
bn = 8;
wm = 4;
wn = 1;
}
int tn = (implicit_N + bn - 1) / bn;
int tm = (implicit_M + bm - 1) / bm;
int swizzle_log = 0;
// Fix small channel specialization
int n_channel_specialization = 0;
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
if (conv_params.C <= 2) {
gemm_k_iters = (implicit_K + bk - 1) / bk;
n_channel_specialization = conv_params.C;
} else if (conv_params.C <= 4) {
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
n_channel_specialization = conv_params.C;
}
bool small_filter = (!n_channel_specialization) &&
(conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
int inp_jump_w = sign * ijw;
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
sign * (conv_params.wS[1] - 1) * ijw;
// Build implicit gemm params
ImplicitGemmConv2DParams gemm_params{
/* const int M = */ implicit_M,
/* const int N = */ implicit_N,
/* const int K = */ implicit_K,
/* const int gemm_k_iterations = */ gemm_k_iters,
/* const int inp_jump_w = */ inp_jump_w,
/* const int inp_jump_h = */ inp_jump_h,
/* const int inp_jump_c = */ inp_jump_c,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
: "l")
<< "_filter_" << (small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
size_t grid_dim_x = (implicit_N + bn - 1) / bn;
size_t grid_dim_y = (implicit_M + bm - 1) / bm;
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
size_t grid_dim_y = (tm + tile - 1) / tile;
size_t grid_dim_x = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
void explicit_gemm_conv_2D_gpu(
void implicit_gemm_conv_2D_general_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
// Pad input
std::vector<int> padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
conv_params.C};
array in_padded(padded_shape, in.dtype(), nullptr, {});
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
// Fill with zeros
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
// Determine block and warp tiles
int wm = 2, wn = 2;
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
conv_params.pad[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Make jump params
int f_wgt_jump_h =
std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
int f_wgt_jump_w =
std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];
// Copy input values into the slice
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
int f_out_jump_h =
std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
int f_out_jump_w =
std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];
// Make strided view
std::vector<int> strided_shape = {
conv_params.N,
conv_params.oS[0],
conv_params.oS[1],
conv_params.wS[0],
conv_params.wS[1],
conv_params.C};
int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
int adj_out_hw = adj_out_h * adj_out_w;
int adj_implicit_m = conv_params.N * adj_out_hw;
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * conv_params.str[0],
in_padded.strides()[2] * conv_params.str[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
Conv2DGeneralJumpParams jump_params{
/* const int f_wgt_jump_h = */ f_wgt_jump_h,
/* const int f_wgt_jump_w = */ f_wgt_jump_w,
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
/* const int f_out_jump_h = */ f_out_jump_h,
/* const int f_out_jump_w = */ f_out_jump_w,
// Materialize strided view
std::vector<int> strided_reshape = {
conv_params.N * conv_params.oS[0] * conv_params.oS[1],
conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
/* const int adj_out_h = */ adj_out_h,
/* const int adj_out_w = */ adj_out_w,
/* const int adj_out_hw = */ adj_out_hw,
/* const int adj_implicit_m = */ adj_implicit_m};
// Perform gemm
std::vector<array> copies = {in_padded, in_strided};
return steel_matmul(
s,
d,
/*a = */ in_strided,
/*b = */ wt,
/*c = */ out,
/*M = */ strided_reshape[0],
/*N = */ conv_params.O,
/*K = */ strided_reshape[1],
/*batch_size_out = */ 1,
/*a_cols = */ strided_reshape[1],
/*b_cols = */ strided_reshape[1],
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
// Make base info
std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);
int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];
int init_h =
(conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
int init_w =
(conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;
int wh_base = 0;
while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
wh_base++;
ih_loop += jump_h;
}
int wh_size =
((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
base_h[i] = {wh_base, wh_size};
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;
int ww_base = 0;
while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
ww_base++;
iw_loop += jump_w;
}
int ww_size =
((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
base_w[j] = {ww_base, ww_size};
}
// Collect block sizes
int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
int bk = 16;
int tn = (implicit_N + bn - 1) / bn;
int tm = (adj_implicit_m + bm - 1) / bm;
int swizzle_log = 0;
// Get channel iteration info
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int gemm_k_iters = channel_k_iters;
// Fix host side helper params
int sign = (conv_params.flip ? -1 : 1);
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
int inp_jump_w = sign * ijw;
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
sign * (conv_params.wS[1] - 1) * ijw;
// Build implicit gemm params
ImplicitGemmConv2DParams gemm_params{
/* const int M = */ implicit_M,
/* const int N = */ implicit_N,
/* const int K = */ implicit_K,
/* const int gemm_k_iterations = */ gemm_k_iters,
/* const int inp_jump_w = */ inp_jump_w,
/* const int inp_jump_h = */ inp_jump_h,
/* const int inp_jump_c = */ inp_jump_c,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int swizzle_log = */ swizzle_log};
// Determine kernel
std::ostringstream kname;
kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
size_t grid_dim_y = (tm + tile - 1) / tile;
size_t grid_dim_x = tn * tile;
size_t grid_dim_z = f_out_jump_h * f_out_jump_w;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
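To make the pointer bookkeeping above easier to follow: inp_jump_w, inp_jump_h, and inp_jump_c are the offsets the steel input loaders add when they step to the next filter column, wrap to the next filter row, or wrap back to the start of the filter for the next channel block. Below is a minimal standalone sketch of that arithmetic with made-up sizes; everything in the snippet (the main(), the constants) is illustrative and not part of this change.

#include <cstdio>

int main() {
  const int wS[2] = {3, 3};   // kernel spatial size
  const int kdil[2] = {2, 2}; // kernel dilation
  const int C = 32, W = 128;  // input channels and width
  const int bk = 16;          // K block size used by the kernel
  const bool flip = false;

  // NHWC input strides; only the row and channel-position strides matter here
  const int in_strides[4] = {W * W * C, W * C, C, 1};

  int sign = flip ? -1 : 1;
  int ijw = in_strides[2] * kdil[1];
  int ijh = in_strides[1] * kdil[0];
  int inp_jump_w = sign * ijw;
  int inp_jump_h = sign * (ijh - (wS[1] - 1) * ijw);
  int inp_jump_c = bk - sign * (wS[0] - 1) * ijh - sign * (wS[1] - 1) * ijw;

  // Next filter column, next filter row, and wrap-around to the next K block
  std::printf(
      "jump_w=%d jump_h=%d jump_c=%d\n", inp_jump_w, inp_jump_h, inp_jump_c);
  return 0;
}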
void winograd_conv_2D_gpu(
@ -299,6 +437,7 @@ void winograd_conv_2D_gpu(
// Fill with zeros
array zero_arr = array(0, in.dtype());
copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
copies_w.push_back(zero_arr);
// Pick input slice from padded
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
@ -327,7 +466,8 @@ void winograd_conv_2D_gpu(
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int dil[NDIM] = */ {1, 1},
/* const int kdil[NDIM] = */ {1, 1},
/* const int idil[NDIM] = */ {1, 1},
/* const size_t in_strides[NDIM + 2] = */
{in_padded.strides()[0],
in_padded.strides()[1],
@ -337,6 +477,8 @@ void winograd_conv_2D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ false,
};
int O_c = conv_params.O;
@ -460,6 +602,8 @@ void conv_2D_gpu(
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
@ -471,37 +615,47 @@ void conv_2D_gpu(
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
// Direct to winograd conv
if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
conv_params.dil[1] == 1) {
winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
(channels_large || (channels_med && inp_large))) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
// Direct to implicit gemm conv
else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
}
// Direct to explicit gemm conv
else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
}
// Direct to fallback conv
else {
slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
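Summarizing the dispatch above: Winograd is taken only for unflipped 3x3 convolutions with unit stride and unit dilations, channel counts that are multiples of 32, and enough work to amortize the transforms; the specialized implicit gemm handles the no-input-dilation cases with friendly channel counts; the general implicit gemm covers the remaining 16-aligned channel cases (including flips and dilations); everything else falls back to the explicit unfold plus gemm path. A standalone sketch of the same selection logic follows; the enum and function name are made up for illustration and are not part of the library.

enum class ConvPath { Winograd, ImplicitGemm, ImplicitGemmGeneral, ExplicitGemm };

ConvPath choose_conv_2d_path(
    int C, int O, int wS0, int wS1,
    bool flip, bool stride_one, bool kdil_one, bool idil_one,
    bool inp_large, bool channels_large, bool channels_med) {
  // 3x3, unit stride/dilation, 32-aligned channels, and enough work
  if (!flip && stride_one && kdil_one && idil_one && wS0 == 3 && wS1 == 3 &&
      C % 32 == 0 && O % 32 == 0 &&
      (channels_large || (channels_med && inp_large))) {
    return ConvPath::Winograd;
  }
  // No input dilation and friendly channel counts
  if (idil_one && (C <= 4 || C % 16 == 0) && (O <= 16 || O % 16 == 0)) {
    return ConvPath::ImplicitGemm;
  }
  // 16-aligned channels: the general implicit gemm handles flips and dilations
  if (C % 16 == 0 && O % 16 == 0) {
    return ConvPath::ImplicitGemmGeneral;
  }
  // Fallback: unfold to a matrix and run a plain gemm
  return ConvPath::ExplicitGemm;
}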
@ -532,11 +686,31 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
// 2D conv
if (out.ndim() == 4) {
conv_2D_gpu(
s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
}
// 1D conv
else if (out.ndim() == 3) {
conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
conv_1D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// Throw error
else {

View File

@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "scatter" << type_to_name(out) << idx_type_name;
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
  // Bail from the fast path (1d index specialization) if the scatter axes
  // aren't the outermost, contiguous dims, since update access would not be
  // in raster order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
}
// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);
// Collect all idx shapes and strides into one place
int idx_ndim = nidx ? inputs[1].ndim() : 0;
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
// Set all the buffers
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);
// Set update info
size_t upd_ndim = upd.ndim();
uint upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
  if (index_nd1_specialization) {
    bool upd_col_contiguous = upd.flags().col_contiguous;
    compute_encoder->setBytes(
        out.shape().data(), out.shape().size() * sizeof(int), 3);
    compute_encoder->setBytes(
        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
    compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
  if (upd_ndim == 0) {
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 3);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
  } else {
    compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
    compute_encoder->setBytes(
        upd.strides().data(), upd_ndim * sizeof(size_t), 4);
  }
  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
  // Set output info
  size_t out_ndim = out.ndim();
  if (out_ndim == 0) {
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 7);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
  } else {
    compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
    compute_encoder->setBytes(
        out.strides().data(), out_ndim * sizeof(size_t), 8);
  }
  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
if (upd_ndim == 0) {
      // Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
      // Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core
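The fast path chosen at the top of Scatter::eval_gpu only fires for 1-D index arrays whose scatter axes are the leading dims and whose index/update inputs are row contiguous; anything else goes through the general kernel. A small host-side sketch of that eligibility test, restated over plain vectors (the function name and argument layout are mine, not this file's API):

#include <vector>

bool use_scatter_1d_fast_path(
    const std::vector<int>& axes,
    int idx_ndim,
    const std::vector<bool>& inputs_row_contiguous) {
  // Only 1-D index arrays qualify
  bool ok = (idx_ndim == 1);
  // The scatter axes must be the outermost dims, in order, so that update
  // access stays in raster order
  for (size_t i = 0; i < axes.size() && ok; ++i) {
    ok &= (axes[i] == static_cast<int>(i));
  }
  // Broadcast (non row-contiguous) indices or updates rule out linear indexing
  for (size_t i = 0; i < inputs_row_contiguous.size() && ok; ++i) {
    ok &= inputs_row_contiguous[i];
  }
  return ok;
}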

View File

@ -3,11 +3,13 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@ -27,6 +29,7 @@ set(
"scan"
"softmax"
"sort"
"ternary"
"unary"
"gather"
"scatter"
@ -48,11 +51,7 @@ endfunction(build_kernel_base)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
endif()
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
endfunction(build_kernel)
foreach(KERNEL ${KERNELS})

View File

@ -11,8 +11,6 @@ template <typename U>
struct IndexValPair {
uint32_t index;
U val;
IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};
template <typename U>
@ -65,10 +63,10 @@ struct ArgMax {
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>(
return IndexValPair<U>{
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
);
};
}
@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
IndexValPair<T> best(0, Op::init);
IndexValPair<T> best{0, Op::init};
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \

View File

@ -2,16 +2,6 @@
#include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,

View File

@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"
typedef half float16_t;

View File

@ -1,481 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DInputBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BM;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
int offsets_n[n_rows];
int offsets_oh[n_rows];
int offsets_ow[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bj),
params(params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params.oS[0] * params.oS[1];
for (int i = 0; i < n_rows; ++i) {
int offset_nhw = tid.y * BM + bi + i * bstride;
offsets_n[i] = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
offsets_oh[i] = hw / params.oS[1];
offsets_ow[i] = hw % params.oS[1];
}
(void)lid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
int n = offsets_n[i];
int oh = offsets_oh[i];
int ow = offsets_ow[i];
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
// Read from input if in bounds
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
const device T* curr_src = src + n * params.in_strides[0] +
ih * params.in_strides[1] + iw * params.in_strides[2];
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
}
// Zero pad otherwise
else {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
template <
typename T,
int BM,
int BN,
int BK,
int vec_size,
int tgp_size,
int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
MLX_MTL_CONST int dst_fd = BN;
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
MLX_MTL_CONST int n_vecs = BK / vec_size;
// Stride along block row within the block
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
MLX_MTL_CONST int n_rows = dst_fd / bstride;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>& params;
int weight_h;
int weight_w;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const constant MLXConvParams<2>& params_,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_.wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / n_vecs),
bj(vec_size * (thread_idx % n_vecs)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_h(0),
weight_w(0) {
(void)lid;
(void)tid;
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src =
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params.wS[1]) {
return;
}
weight_w = 0;
if (++weight_h < params.wS[0]) {
return;
}
weight_h = 0;
src += BK;
}
};
///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
int tgp_padding_a = 0,
int tgp_padding_b = 0,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
// Warp tile size along M
MLX_MTL_CONST int TM = BM / (WM * 8);
// Warp tile size along N
MLX_MTL_CONST int TN = BN / (WN * 8);
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
MLX_MTL_CONST int TN_stride = 8 * WN;
// Leading dimensions of threadgroup A, B blocks
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
// Strides of A, B along reduction axis
MLX_MTL_CONST short simd_stride_a =
transpose_a ? TM_stride : TM_stride * lda_tgp;
MLX_MTL_CONST short simd_stride_b =
transpose_b ? TN_stride * ldb_tgp : TN_stride;
// Jump between elements
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
// Offsets within threadgroup
const int tm;
const int tn;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
short sm;
short sn;
/* Constructor */
METAL_FUNC Conv2DBlockMMA(
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
for (short kk = 0; kk < BK; kk += 8) {
short2 offset_a =
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
short2 offset_b =
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
As__ += simd_stride_a;
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
Bs__ += simd_stride_b;
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (short j = 0; j < TN; j++) {
simdgroup_multiply_accumulate(
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
METAL_FUNC void
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
for (int i = 0; i < TM; i++) {
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
for (int j = 0; j < TN; j++) {
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
}
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
}
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
MLX_MTL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
MLX_MTL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
MLX_MTL_CONST short tgp_size = WM * WN * 32;
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
using loader_a_t =
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
using loader_b_t =
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
using mma_t = Conv2DBlockMMA<
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
tgp_padding_a,
tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const int K = params.wt_strides[0];
const int N = params.O;
B += c_col * K;
C += c_row * N + c_col;
// Prepare threadgroup memory for loading
threadgroup T* As = tgp_memory;
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
mma_op.store_result(C, N);
}
};

View File

@ -1,16 +1,102 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv.h"
#define MLX_MTL_CONST static constant constexpr const
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive kernels
/// Naive unfold with dilation
///////////////////////////////////////////////////////////////////////////////
template <typename T, int N>
[[kernel]] void naive_unfold_Nd(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
// Set out
out += gid.z * filter_size + gid.y * (params->C);
  // Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)
int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int wS = gid.y;
bool valid = n < params->N;
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
is[i] = is_ / params->idil[i];
oS /= params->oS[i];
wS /= params->wS[i];
}
if(valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
out[gid.x] = in[in_offset + gid.x];
} else {
out[gid.x] = T(0);
}
}
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
[[kernel]] void naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) \
instantiate_naive_unfold_nd(name, itype, 2) \
instantiate_naive_unfold_nd(name, itype, 3)
instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
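The unfold above is the standard im2col, extended with flipping, kernel dilation, and input dilation. A host-side reference for the 1-D case is sketched below for sanity checking; the function name and the layout assumptions (row-major (N, iS, C) input, (N * oS, wS * C) output) are illustrative only.

#include <vector>

std::vector<float> unfold_1d_reference(
    const std::vector<float>& in, // shape (N, iS, C), row-major
    int N, int iS, int C, int wS, int oS,
    int str, int pad, int kdil, int idil, bool flip) {
  std::vector<float> out(size_t(N) * oS * wS * C, 0.0f);
  for (int n = 0; n < N; ++n) {
    for (int os = 0; os < oS; ++os) {
      for (int ws = 0; ws < wS; ++ws) {
        int w = flip ? wS - ws - 1 : ws;
        // Same formula as the kernel: output position * stride - padding,
        // plus the (possibly flipped) tap offset scaled by kernel dilation
        int is = os * str - pad + w * kdil;
        int is_max = 1 + idil * (iS - 1);
        bool valid = is >= 0 && is < is_max && (is % idil == 0);
        for (int c = 0; c < C; ++c) {
          size_t dst = ((size_t(n) * oS + os) * wS + ws) * C + c;
          if (valid) {
            size_t src = (size_t(n) * iS + is / idil) * C + c;
            out[dst] = in[src];
          }
        }
      }
    }
  }
  return out;
}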
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
@ -58,8 +144,8 @@ template <typename T,
// Local in
for(int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
@ -116,59 +202,6 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
device T* out [[buffer(2)]],
const constant MLXConvParams<2>& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;
threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
gemm_kernel::run(
in, wt, out,
params, tgp_memory,
tid, lid, simd_gid, simd_lid
);
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

View File

@ -1,19 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int dil[NDIM]; // Kernel dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
};

View File

@ -13,6 +13,58 @@ using namespace metal;
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX> \
METAL_FUNC void scatter_1d_index_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant bool& upd_col_contiguous [[buffer(6)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(
idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
if (!upd_col_contiguous) {
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
} else {
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
}
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
out_shape, \
out_strides, \
upd_size, \
upd_col_contiguous, \
idx_buffers, \
gid); \
\
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
@ -46,10 +98,14 @@ METAL_FUNC void scatter_impl(
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
op.atomic_update(out, updates[upd_idx], out_idx);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
@ -90,9 +146,11 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
axes, \
idxs, \
gid); \
}
}
#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)
make_scatter(0)
make_scatter(1)
@ -129,8 +187,21 @@ template [[host_name("scatter" name "_" #nidx)]] \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \

View File

@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal;
using namespace mlx::steel;

View File

@ -0,0 +1,189 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>
>
>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
B += c_col * K;
C += c_row * N + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);

View File

@ -0,0 +1,209 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
using namespace mlx::steel;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = Conv2DInputBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t = Conv2DWeightBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int tid_z = tid.z;
const int base_oh = tid_z / jump_params->f_out_jump_w;
const int base_ow = tid_z % jump_params->f_out_jump_w;
const int base_wh = base_h[base_oh].weight_base;
const int base_ww = base_w[base_ow].weight_base;
const int base_wh_size = base_h[base_oh].weight_size;
const int base_ww_size = base_w[base_ow].weight_size;
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
B += c_col * K;
const int4 offsets_a(0, c_row, base_oh, base_ow);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
{
    // Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
C += offset_n;
if (offset_n >= gemm_params->N)
return;
short diff = gemm_params->N - offset_n;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);

View File

@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"

View File

@ -0,0 +1,449 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderLargeFilter {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_h;
short weight_w;
const device T* src[n_rows];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderLargeFilter(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
gemm_params(gemm_params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params->oS[0] * params->oS[1];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
int oh = hw / params->oS[1];
int ow = hw % params->oS[1];
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
// Adjust for flip
if (params->flip) {
ih += (params->wS[0] - 1) * params->kdil[0];
iw += (params->wS[1] - 1) * params->kdil[1];
}
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
iw * params->in_strides[2] + bj;
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int ih = read_ih[i] + weight_h * params->kdil[0];
int iw = read_iw[i] + weight_w * params->kdil[1];
// Read from input if in bounds
if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
(iw >= 0 && iw < params->iS[1])) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = src[i][j];
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params->wS[1]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_w;
}
return;
}
weight_w = 0;
if (++weight_h < params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_h;
}
return;
}
weight_h = 0;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_c;
}
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallFilter {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
using mask_t = short;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_h;
short weight_w;
const device T* src[n_rows];
mask_t mask_h[n_rows];
mask_t mask_w[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderSmallFilter(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
gemm_params(gemm_params_),
weight_h(0),
weight_w(0) {
int out_n_pixels = params->oS[0] * params->oS[1];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
int oh = hw / params->oS[1];
int ow = hw % params->oS[1];
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
// Adjust for flip
if (params->flip) {
ih += (params->wS[0] - 1) * params->kdil[0];
iw += (params->wS[1] - 1) * params->kdil[1];
}
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
iw * params->in_strides[2] + bj;
}
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
mask_h[i] = 0;
mask_w[i] = 0;
}
for (short kh = 0; kh < params->wS[0]; kh++) {
short flip_h = params->flip ? params->wS[0] - kh - 1 : kh;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int n = read_n[i];
int ih = read_ih[i] + flip_h * params->kdil[0];
bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];
mask_h[i] |= (in_bounds << kh);
}
}
for (short kw = 0; kw < params->wS[1]; kw++) {
short flip_w = params->flip ? params->wS[1] - kw - 1 : kw;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int iw = read_iw[i] + flip_w * params->kdil[1];
bool in_bounds = iw >= 0 && iw < params->iS[1];
mask_w[i] |= (in_bounds << kw);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
mask_t h_mask = mask_t(1) << weight_h;
mask_t w_mask = mask_t(1) << weight_w;
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Read from input if in bounds
if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = src[i][j];
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_w < params->wS[1]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_w;
}
return;
}
weight_w = 0;
if (++weight_h < params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_h;
}
return;
}
weight_h = 0;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += gemm_params->inp_jump_c;
}
}
};
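// Illustrative host-side sketch (not part of this commit): how the
// small-filter loader trades per-iteration bounds checks for precomputed bit
// masks. Bit kh of mask_h is set when input row read_ih + kh * kdil lands
// inside [0, iS); at load time a single AND against (1 << weight_h) decides
// whether to read or zero-pad. The name make_h_mask and the flip-free
// handling are assumptions for illustration only.
#include <cstdint>
inline uint16_t make_h_mask(int read_ih, int wS, int kdil, int iS) {
  uint16_t mask = 0;
  for (int kh = 0; kh < wS; ++kh) {
    int ih = read_ih + kh * kdil; // input row touched by kernel tap kh
    if (ih >= 0 && ih < iS) {
      mask |= uint16_t(1) << kh;  // tap kh is in bounds for this row
    }
  }
  return mask;
}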
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DWeightBlockLoader {
// Destination dimensions
STEEL_CONST short BROWS = BN;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size =
(BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>* params;
int weight_hw;
const int read_n;
const bool do_read;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoader(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BN; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
}
} else {
for (short i = 0; i < BN; i += TROWS) {
if ((read_n + i) < params->O) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = src[i * src_ld + j];
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
src += params->wt_strides[2];
return;
}
weight_hw = 0;
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
}
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,319 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <short n_channels_>
struct ChannelHelper {
STEEL_CONST short n_channels = n_channels_;
STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
STEEL_CONST short excess = vec_size - n_channels_;
};
template <>
struct ChannelHelper<1> {
STEEL_CONST short n_channels = 1;
STEEL_CONST short vec_size = 1;
STEEL_CONST short excess = 0;
};
template <>
struct ChannelHelper<2> {
STEEL_CONST short n_channels = 2;
STEEL_CONST short vec_size = 2;
STEEL_CONST short excess = 0;
};
template <>
struct ChannelHelper<3> {
STEEL_CONST short n_channels = 3;
STEEL_CONST short vec_size = 4;
STEEL_CONST short excess = 1;
};
template <>
struct ChannelHelper<4> {
STEEL_CONST short n_channels = 4;
STEEL_CONST short vec_size = 4;
STEEL_CONST short excess = 0;
};
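// Illustrative check (not part of this commit): ChannelHelper rounds the
// per-pixel channel read up to an aligned vector width and zero-fills the
// excess lanes; e.g. 3 channels are loaded as a 4-wide vector with one padded
// lane, and anything above 4 channels takes the 8-wide path.
static_assert(ChannelHelper<3>::vec_size == 4 && ChannelHelper<3>::excess == 1, "");
static_assert(ChannelHelper<5>::vec_size == 8 && ChannelHelper<5>::excess == 3, "");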
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short n_channels,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallChannels {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant ImplicitGemmConv2DParams* gemm_params;
short weight_hw;
const device T* src[n_rows];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
gemm_params(gemm_params_),
weight_hw(thread_idx % TCOLS) {
int out_n_pixels = params->oS[0] * params->oS[1];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / out_n_pixels;
int hw = offset_nhw % out_n_pixels;
int oh = hw / params->oS[1];
int ow = hw % params->oS[1];
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
iw * params->in_strides[2];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
if (weight_hw >= params->wS[1] * params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
int wh = (weight_hw / params->wS[1]);
int ww = (weight_hw % params->wS[1]);
int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;
int weight_h = flip_h * params->kdil[0];
int weight_w = flip_w * params->kdil[1];
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int ih = read_ih[i] + weight_h;
int iw = read_iw[i] + weight_w;
// Read from input if in bounds
if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
(iw >= 0 && iw < params->iS[1])) {
const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
weight_w * params->in_strides[2];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < n_channels; ++j) {
dst[is * dst_ld + j] = curr_src[j];
}
STEEL_PRAGMA_UNROLL
for (short j = n_channels; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_hw += TCOLS;
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short n_channels,
short tgp_padding = 0>
struct Conv2DWeightBlockLoaderSmallChannels {
// Destination dimensions
STEEL_CONST short BROWS = BN;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>* params;
int weight_hw;
const int read_n;
const bool do_read;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant ImplicitGemmConv2DParams* gemm_params_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld),
params(params_),
weight_hw(thread_idx % TCOLS),
read_n(offsets.y + bi),
do_read(read_n + BN <= gemm_params_->N) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
if (bi >= BROWS || bj >= BCOLS)
return;
if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
if (BN != 8 || do_read) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < n_channels; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
STEEL_PRAGMA_UNROLL
for (short j = n_channels; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
} else {
for (short i = 0; i < BROWS; i += TROWS) {
if (((read_n + i) < params->O)) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < n_channels; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
STEEL_PRAGMA_UNROLL
for (short j = n_channels; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_hw += TCOLS;
}
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,288 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
// Destination dimensions
STEEL_CONST short BROWS = BM;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const constant MLXConvParams<2>* params;
const constant Conv2DGeneralJumpParams* jump_params;
const short base_wh;
const short base_ww;
short weight_h;
short weight_w;
const device T* src[n_rows];
int read_n[n_rows];
int read_ih[n_rows];
int read_iw[n_rows];
/* Constructor */
METAL_FUNC Conv2DInputBlockLoaderGeneral(
const device T* src_,
threadgroup T* dst_,
const int4 offsets,
const constant MLXConvParams<2>* params_,
const constant Conv2DGeneralJumpParams* jump_params_,
const short base_wh_,
const short base_ww_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
params(params_),
jump_params(jump_params_),
base_wh(base_wh_),
base_ww(base_ww_),
weight_h(base_wh_),
weight_w(base_ww_) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; ++i) {
int offset_nhw = offsets.y + bi + i * TROWS;
int n = offset_nhw / jump_params->adj_out_hw;
int hw = offset_nhw % jump_params->adj_out_hw;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;
int ih = oh * params->str[0] - params->pad[0];
int iw = ow * params->str[1] - params->pad[1];
read_n[i] = n;
read_ih[i] = ih;
read_iw[i] = iw;
// Read from input if in bounds
src[i] = src_ + n * params->in_strides[0] + bj;
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
// Find bounds
int n = read_n[i];
int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
int ih_dil = read_ih[i] + h_flip * params->kdil[0];
int iw_dil = read_iw[i] + w_flip * params->kdil[1];
int ih = ih_dil / params->idil[0];
int iw = iw_dil / params->idil[1];
size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
// Read from input if in bounds
if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
(iw_dil >= 0 && iw < params->iS[1])) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = (src[i])[offset + j];
}
}
// Zero pad otherwise
else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; ++j) {
dst[is * dst_ld + j] = T(0);
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;
if (weight_w < params->wS[1]) {
return;
}
weight_w = base_ww;
weight_h += jump_params->f_wgt_jump_h;
if (weight_h < params->wS[0]) {
return;
}
weight_h = base_wh;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < n_rows; i++) {
src[i] += BK;
}
}
};
template <
typename T,
short BM,
short BN,
short BK,
short tgp_size,
short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
// Destination dimensions
STEEL_CONST short BROWS = BN;
STEEL_CONST short BCOLS = BK;
// Read dimensions
STEEL_CONST short dst_ld = BCOLS + tgp_padding;
STEEL_CONST short vec_size =
(BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
// Thread read shape
STEEL_CONST short TCOLS = BCOLS / vec_size;
STEEL_CONST short TROWS = tgp_size / TCOLS;
// Rows / strided reads within the block
STEEL_CONST short n_rows = BROWS / TROWS;
// Leading dimension for src
const int src_ld;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
const constant MLXConvParams<2>* params;
const constant Conv2DGeneralJumpParams* jump_params;
const short base_wh;
const short base_ww;
short weight_h;
short weight_w;
const int start_row;
/* Constructor */
METAL_FUNC Conv2DWeightBlockLoaderGeneral(
const device T* src_,
threadgroup T* dst_,
const int2 offsets,
const constant MLXConvParams<2>* params_,
const constant Conv2DGeneralJumpParams* jump_params_,
const short base_wh_,
const short base_ww_,
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(params_->wt_strides[0]),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj),
params(params_),
jump_params(jump_params_),
base_wh(base_wh_),
base_ww(base_ww_),
weight_h(base_wh_),
weight_w(base_ww_),
start_row(offsets.y + bi) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
const device T* curr_src = src + weight_h * params->wt_strides[1] +
weight_w * params->wt_strides[2];
if ((start_row + BN <= params->O)) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BN; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
}
} else {
for (short i = 0; i < BN; i += TROWS) {
if ((start_row + i) < params->O) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
}
} else {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
weight_w += jump_params->f_wgt_jump_w;
if (weight_w < params->wS[1]) {
return;
}
weight_w = base_ww;
weight_h += jump_params->f_wgt_jump_h;
if (weight_h < params->wS[0]) {
return;
}
weight_h = base_wh;
src += BK;
}
};
} // namespace steel
} // namespace mlx

View File

@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int kdil[NDIM]; // Kernel dilation
const int idil[NDIM]; // Input dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
const int groups; // Input channel groups
const bool flip;
};
namespace mlx {
namespace steel {
struct ImplicitGemmConv2DParams {
const int M;
const int N;
const int K;
const int gemm_k_iterations;
const int inp_jump_w;
const int inp_jump_h;
const int inp_jump_c;
const int tiles_n;
const int tiles_m;
const int swizzle_log;
};
struct Conv2DGeneralJumpParams {
const int f_wgt_jump_h;
const int f_wgt_jump_w;
const int f_out_jump_h;
const int f_out_jump_w;
const int adj_out_h;
const int adj_out_w;
const int adj_out_hw;
const int adj_implicit_m;
};
struct Conv2DGeneralBaseInfo {
int weight_base;
int weight_size;
};
} // namespace steel
} // namespace mlx
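// Illustrative sketch (not part of this commit): how a 2-D convolution is
// mapped onto the implicit GEMM these params describe. Every output pixel
// becomes a GEMM row, every output channel a column, and the reduction runs
// over all kernel taps times the input channels. The helper name
// make_gemm_dims is an assumption for illustration only.
struct GemmDims {
  int M, N, K;
};
inline GemmDims make_gemm_dims(const MLXConvParams<2>& p) {
  return GemmDims{
      p.N * p.oS[0] * p.oS[1], // M: batch * output spatial positions
      p.O,                     // N: output channels
      p.wS[0] * p.wS[1] * p.C  // K: kernel taps * input channels
  };
}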

View File

@ -4,6 +4,7 @@
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

View File

@ -2,9 +2,15 @@
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
@ -167,6 +173,9 @@ struct BlockMMA {
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
@ -236,6 +245,9 @@ struct BlockMMA {
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {

View File

@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/params.h"

View File

@ -3,7 +3,6 @@
#pragma once
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/host.h"
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@ -0,0 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
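// Illustrative host-side sketch (not part of this commit): Select is the
// element-wise functor the ternary kernels apply for where/select; each
// thread evaluates it on one (condition, x, y) triple.
#include <cassert>
inline void select_demo() {
  Select op;
  assert(op(true, 1.0f, 2.0f) == 1.0f);  // condition true  -> x
  assert(op(false, 1.0f, 2.0f) == 2.0f); // condition false -> y
}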

View File

@ -0,0 +1,201 @@
// Copyright © 2023 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/ternary.h"
template <typename T, typename Op>
[[kernel]] void ternary_op_v(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
uint index [[thread_position_in_grid]]) {
d[index] = Op()(a[index], b[index], c[index]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd1(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd3(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_op_g_nd(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
constant const size_t c_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
#define instantiate_ternary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_v<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
uint index [[thread_position_in_grid]]); \
#define instantiate_ternary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void ternary_op_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name(name "_1")]] \
[[kernel]] void ternary_op_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void ternary_op_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void ternary_op_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5) \
#define instantiate_ternary_all(name, tname, type, op) \
instantiate_ternary_v("v" #name #tname, type, op) \
instantiate_ternary_g("g" #name #tname, type, op) \
instantiate_ternary_g_nd("g" #name #tname, type, op) \
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) \
instantiate_ternary_types(select, Select)

View File

@ -9,6 +9,10 @@
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/utils.h"
namespace {
constant float inf = metal::numeric_limits<float>::infinity();
}
struct Abs {
template <typename T>
T operator()(T x) {

View File

@ -91,6 +91,30 @@ inline size_t elem_to_loc(
return loc;
}
template <int NDIM>
inline uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM]) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
static_cast<uint>(
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
@ -150,6 +174,30 @@ inline size_t elem_to_loc(
return loc;
}
inline uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
static_cast<uint>(
elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
}
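// Illustrative host-side check (not part of this commit): the dynamic-ndim
// variant maps a thread's (x, y, z) grid position to one linear offset per
// input. The mirror below uses a single strides array to keep the example
// short; the real helper computes three offsets at once.
#include <cassert>
#include <cstddef>
inline void elem_to_loc_3_nd_demo() {
  const int shape[3] = {2, 3, 4};
  const size_t strides[3] = {12, 4, 1};
  unsigned ex = 1, ey = 2, ez = 1;                // thread position in the grid
  size_t loc = ex * strides[2] + ey * strides[1]; // fastest two axes from x, y
  loc += (ez % shape[0]) * strides[0];            // remaining axis from z
  assert(loc == 21); // 1*1 + 2*4 + 1*12
}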
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,

View File

@ -8,7 +8,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/host.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <future>
@ -10,6 +10,10 @@
namespace mlx::core::metal {
bool is_available() {
return true;
}
int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@ -11,16 +11,52 @@
namespace mlx::core::metal {
constexpr bool is_available() {
#ifdef _METAL_
return true;
#else
return false;
#endif
}
bool is_available();
bool cache_enabled(void);
void set_cache_enabled(bool enabled);
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used is recorded from the beginning of the program
* execution.
* */
size_t get_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
* there are no more scheduled tasks, an error will be raised when relaxed
* is false, or memory will be allocated (including the potential for
* swap) when relaxed is true.
*
* The memory limit defaults to 1.5 times the maximum recommended working set
* size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit, bool relaxed = true);
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
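// Illustrative usage sketch (not part of this commit): capping the Metal
// buffer cache and reading the new memory counters from C++. The include
// path and the 2 GiB figure are assumptions for illustration only.
#include <cstddef>
#include "mlx/backend/metal/metal.h"
inline void memory_limits_example() {
  namespace metal = mlx::core::metal;
  if (!metal::is_available()) {
    return;
  }
  size_t old_cache_limit = metal::set_cache_limit(size_t(2) << 30); // 2 GiB
  size_t active = metal::get_active_memory(); // bytes held by live arrays
  size_t peak = metal::get_peak_memory();     // high-water mark since start
  size_t cached = metal::get_cache_memory();  // reusable, not yet released
  (void)old_cache_limit;
  (void)active;
  (void)peak;
  (void)cached;
}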

View File

@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
@ -43,24 +43,25 @@ void binary_op(
std::ostringstream kname;
switch (bopt) {
case ScalarScalar:
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case VectorVector:
case BinaryOpType::VectorVector:
kname << "vv";
break;
case General:
case BinaryOpType::General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
if (bopt == BinaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
@ -80,7 +81,7 @@ void binary_op(
set_array_buffer(compute_encoder, outputs[0], 2);
set_array_buffer(compute_encoder, outputs[1], 3);
if (bopt == General) {
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
@ -141,24 +142,25 @@ void binary_op(
std::ostringstream kname;
switch (bopt) {
case ScalarScalar:
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case ScalarVector:
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case VectorScalar:
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case VectorVector:
case BinaryOpType::VectorVector:
kname << "vv";
break;
case General:
case BinaryOpType::General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
if (bopt == BinaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
@ -173,7 +175,7 @@ void binary_op(
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
set_array_buffer(compute_encoder, out, 2);
if (bopt == General) {
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
@ -202,7 +204,94 @@ void binary_op(
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = bopt == General ? out.size() : out.data_size();
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
void ternary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 3);
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_c = strides[2];
auto& strides_out = strides[3];
std::ostringstream kname;
if (topt == TernaryOpType::General) {
kname << "g";
kname << op << type_to_name(b);
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
} else {
kname << "v";
kname << op << type_to_name(b);
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
if (topt == TernaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
@ -430,8 +519,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
compute_encoder->setThreadgroupMemoryLength(
simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
@ -621,6 +708,10 @@ void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op(inputs, out, "select");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
}

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <stdexcept>
@ -6,6 +6,10 @@
namespace mlx::core::metal {
bool is_available() {
return false;
}
void new_stream(Stream) {}
std::shared_ptr<void> new_scoped_memory_pool() {
return nullptr;
@ -19,10 +23,21 @@ std::function<void()> make_task(
"[metal::make_task] Cannot make GPU task without metal backend");
}
// No cache for CPU only
bool cache_enabled(void) {
return false;
// No-ops when Metal is not available.
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t, bool) {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
void set_cache_enabled(bool) {}
} // namespace mlx::core::metal

View File

@ -80,6 +80,7 @@ NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)

View File

@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
typeid(p) == typeid(Subtract));
}
bool is_ternary(const Primitive& p) {
return typeid(p) == typeid(Select);
}
bool is_broadcast(const Primitive& p) {
return typeid(p) == typeid(Broadcast);
}
@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
is_noop(p);
}
bool allows_shapeless(const Primitive& p) {
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
typeid(p) == typeid(Select);
}
Compiled::Compiled(

View File

@ -46,10 +46,6 @@ struct Dtype {
};
};
inline bool is_available(const Dtype& dtype) {
return true;
}
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};

View File

@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <climits>
#include <cmath>
#include <numeric>
#include <set>
@ -73,10 +74,24 @@ array arange(
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
double real_size = std::ceil((stop - start) / step);
if (std::isnan(real_size)) {
if (std::isinf(start) || std::isinf(stop)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
// Check if start and stop specify a valid range because if not, we have to
// return an empty array
if (std::isinf(step) &&
(step > 0 && start < stop || step < 0 && start > stop)) {
return array({start}, dtype);
}
double real_size = std::ceil((stop - start) / step);
if (real_size > INT_MAX) {
throw std::invalid_argument("[arange] Maximum size exceeded.");
}
int size = std::max(static_cast<int>(real_size), 0);
return array(
{size},
@ -1149,13 +1164,20 @@ array isneginf(const array& a, StreamOrDevice s /* = {} */) {
}
array where(
const array& condition,
const array& x,
const array& y,
const array& a,
const array& b,
const array& c,
StreamOrDevice s /* = {} */) {
// TODO, fix this to handle the NaN case when x has infs
auto mask = astype(condition, bool_, s);
return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s);
auto condition = astype(a, bool_, s);
Dtype out_dtype = promote_types(b.dtype(), c.dtype());
auto inputs = broadcast_arrays(
{condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s);
return array(
inputs[0].shape(),
out_dtype,
std::make_unique<Select>(to_stream(s)),
inputs);
}
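// Illustrative usage sketch (not part of this commit): with the Select
// primitive, where() promotes the two value arrays to a common dtype and
// broadcasts all three inputs before dispatch. A standalone host-side example:
#include "mlx/mlx.h"
inline void where_example() {
  using namespace mlx::core;
  array cond = array({true, false, false, true}); // bool, shape (4,)
  array x = full({4}, 1.0f);                      // float32
  array y = zeros({4}, int32);                    // promoted to float32 with x
  array out = where(cond, x, y);                  // float32, shape (4,)
  eval(out);
}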
array allclose(
@ -1678,7 +1700,7 @@ array argpartition(
int kth_ = kth < 0 ? kth + a.shape(axis) : kth;
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
std::ostringstream msg;
msg << "[argpartition] Received invalid kth " << kth << "along axis "
msg << "[argpartition] Received invalid kth " << kth << " along axis "
<< axis << " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str());
}
@ -1699,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
// Check for valid axis
int axis_ = axis < 0 ? axis + a.ndim() : axis;
int kth_ = k < 0 ? k + a.shape(axis) : k;
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[topk] Received invalid axis " << axis << " for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
if (k < 0 || k > a.shape(axis_)) {
std::ostringstream msg;
msg << "[topk] Received invalid k " << k << "along axis " << axis
msg << "[topk] Received invalid k=" << k << " along axis " << axis
<< " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str());
}
array a_partitioned = partition(a, kth_, axis_, s);
// Return early if the whole input was requested.
if (k == a.shape(axis_)) {
return a;
}
array a_partitioned = partition(a, -k, axis_, s);
std::vector<int> slice_starts(a.ndim(), 0);
std::vector<int> slice_ends = a.shape();
slice_starts[axis_] = kth_;
slice_starts[axis_] = a.shape(axis_) - k;
return slice(a_partitioned, slice_starts, slice_ends, s);
}
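// Illustrative usage sketch (not part of this commit): topk now partitions
// with -k so the k largest entries land in the trailing slots of the axis,
// then slices those slots out; ordering inside the result is unspecified.
#include "mlx/mlx.h"
inline void topk_example() {
  using namespace mlx::core;
  array a = array({3, 1, 4, 1, 5});
  array top2 = topk(a, 2); // contains 4 and 5, in unspecified order
  eval(top2);
}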
@ -1733,7 +1759,11 @@ array logsumexp(
StreamOrDevice s /* = {}*/) {
auto maxval = stop_gradient(max(a, axes, true, s));
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
return add(out, reshape(maxval, out.shape(), s), s);
out = add(out, reshape(maxval, out.shape(), s), s);
if (!keepdims) {
maxval = squeeze(maxval, axes, s);
}
return where(isinf(maxval, s), maxval, out, s);
}
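// Illustrative sketch (not part of this commit): why the final where() guard
// matters. logsumexp uses the shifted form max(a) + log(sum(exp(a - max(a)))),
// and when the max is +/-inf the shift evaluates exp(inf - inf) = nan, so the
// guarded path returns maxval directly, which is the correct limit.
#include <algorithm>
#include <cmath>
inline float logsumexp2_demo(float x, float y) {
  float m = std::max(x, y);
  if (std::isinf(m)) {
    return m; // +/-inf max: return it directly; the shifted form would give nan
  }
  return m + std::log(std::exp(x - m) + std::exp(y - m));
}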
array logsumexp(
@ -2670,33 +2700,78 @@ array cummin(
namespace {
// Conv helpers
inline int conv_out_axis_size(
int in_dim,
int wt_dim,
int stride,
int padding,
int dilation) {
int ker = dilation * (wt_dim - 1);
return ((in_dim + 2 * padding - ker - 1) / stride) + 1;
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
return ((in_dim + padding - wt_dim) / stride) + 1;
}
// Conv helpers
inline int dilate_size(int dim, int dil) {
return 1 + dil * (dim - 1);
}
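// Illustrative check (not part of this commit): for an input axis of size 32,
// a 3-tap kernel dilated by 2 (effective extent dilate_size(3, 2) = 5),
// stride 2, and total padding 2, the output axis has
// ((32 + 2 - 5) / 2) + 1 = 15 positions.
#include <cassert>
inline void conv_axis_size_demo() {
  assert(dilate_size(3, 2) == 5);
  assert(conv_out_axis_size(32, dilate_size(3, 2), /* stride */ 2, /* padding */ 2) == 15);
}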
inline std::vector<int> conv_out_shape(
const std::vector<int>& in_shape,
const std::vector<int>& wt_shape,
const std::vector<int>& strides,
const std::vector<int>& pads,
const std::vector<int>& dilation) {
const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation) {
int N = in_shape[0];
int O = wt_shape[0];
std::vector<int> out_shape(in_shape.size());
int i = 0;
out_shape[i++] = N;
int spatial_dims = in_shape.size() - 2;
if (strides.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid strides " << strides << "for " << spatial_dims
<< "D convolution.";
throw std::invalid_argument(msg.str());
}
if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid pading " << pads_lo << " | " << pads_hi << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (kernel_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid kernel dilation " << kernel_dilation << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
if (input_dilation.size() != spatial_dims) {
std::ostringstream msg;
msg << "[conv] Invalid input dilation " << input_dilation << "for "
<< spatial_dims << "D convolution.";
throw std::invalid_argument(msg.str());
}
for (; i < in_shape.size() - 1; i++) {
if (pads[i - 1] < 0) {
if (kernel_dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Kernel dilation sizes must be positive."
<< " Got kernel dilation " << kernel_dilation << ".";
throw std::invalid_argument(msg.str());
}
if (input_dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Input dilation sizes must be positive."
<< " Got input dilation " << input_dilation << ".";
throw std::invalid_argument(msg.str());
}
if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
std::ostringstream msg;
msg << "[conv] Padding sizes must be non-negative."
<< " Got padding " << pads << ".";
<< " Got padding " << pads_lo << " | " << pads_hi << ".";
throw std::invalid_argument(msg.str());
}
@ -2707,22 +2782,19 @@ inline std::vector<int> conv_out_shape(
throw std::invalid_argument(msg.str());
}
if (dilation[i - 1] <= 0) {
std::ostringstream msg;
msg << "[conv] Dilation sizes must be positive."
<< " Got dilation " << dilation << ".";
throw std::invalid_argument(msg.str());
}
int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
int id = dilate_size(in_shape[i], input_dilation[i - 1]);
out_shape[i] = conv_out_axis_size(
in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);
if (out_shape[i] <= 0) {
std::ostringstream msg;
msg << "[conv] Spatial dimensions of input after padding "
<< " cannot be smaller than weight spatial dimensions."
<< " Got input with shape " << in_shape << " and padding " << pads
<< " for weight of shape " << wt_shape << ".";
<< " Got error at axis " << i << " for input with shape " << in_shape
<< ", padding low " << pads_lo << ", padding high " << pads_hi
<< ", and weight of shape " << wt_shape << ".";
throw std::invalid_argument(msg.str());
}
}
@ -2777,43 +2849,16 @@ array conv1d(
int dilation /* = 1 */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv1d] Cannot handle groups != 1 yet");
}
if (dilation != 1) {
throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet");
}
// Run checks
run_conv_checks(in_, wt_, 1);
auto in = in_;
auto wt = wt_;
// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());
in = astype(in, out_type, s);
wt = astype(wt, out_type, s);
std::vector<int> strides_vec = {stride};
std::vector<int> padding_vec = {padding};
std::vector<int> dilation_vec = {dilation};
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec);
return array(
out_shape,
in.dtype(),
std::make_unique<Convolution>(
to_stream(s),
padding_vec,
strides_vec,
dilation_vec,
std::vector<int>(1, 1)),
{in, wt});
return conv_general(
/* const array& input = */ in_,
/* const array& weight = */ wt_,
/* std::vector<int> stride = */ {stride},
/* std::vector<int> padding = */ {padding},
/* std::vector<int> kernel_dilation = */ {dilation},
/* std::vector<int> input_dilation = */ {1},
/* int groups = */ groups,
/* bool flip = */ false,
s);
}
/** 2D convolution with a filter */
@ -2825,42 +2870,98 @@ array conv2d(
const std::pair<int, int>& dilation /* = {1, 1} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_general(
/* const array& input = */ in_,
/* const array& weight = */ wt_,
/* std::vector<int> stride = */ {stride.first, stride.second},
/* std::vector<int> padding = */ {padding.first, padding.second},
/* std::vector<int> kernel_dilation = */
{dilation.first, dilation.second},
/* std::vector<int> input_dilation = */ {1, 1},
/* int groups = */ groups,
/* bool flip = */ false,
s);
}
/** General convolution with a filter */
array conv_general(
array in,
array wt,
std::vector<int> stride /* = {} */,
std::vector<int> padding_lo /* = {} */,
std::vector<int> padding_hi /* = {} */,
std::vector<int> kernel_dilation /* = {} */,
std::vector<int> input_dilation /* = {} */,
int groups /* = 1 */,
bool flip /* = false */,
StreamOrDevice s /* = {} */) {
// Run checks
if (groups != 1) {
throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet");
throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
}
if (dilation.first != 1 || dilation.second != 1) {
throw std::invalid_argument("[conv2d] Cannot handle dilation != 1 yet");
int spatial_dims = in.ndim() - 2;
if (spatial_dims < 1 || spatial_dims > 2) {
throw std::invalid_argument(
"[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
" The inputs must be in the format [N, ..., C_in]");
}
// Run checks
run_conv_checks(in_, wt_, 2);
auto in = in_;
auto wt = wt_;
run_conv_checks(in, wt, spatial_dims);
// Type promotion
auto out_type = promote_types(in.dtype(), wt.dtype());
in = astype(in, out_type, s);
wt = astype(wt, out_type, s);
std::vector<int> strides_vec = {stride.first, stride.second};
std::vector<int> padding_vec = {padding.first, padding.second};
std::vector<int> dilation_vec = {dilation.first, dilation.second};
if (stride.size() <= 1) {
int stride_int = stride.size() ? stride[0] : 1;
stride = std::vector<int>(spatial_dims, stride_int);
}
if (padding_lo.size() <= 1) {
int padding_int = padding_lo.size() ? padding_lo[0] : 0;
padding_lo = std::vector<int>(spatial_dims, padding_int);
}
if (padding_hi.size() <= 1) {
int padding_int = padding_hi.size() ? padding_hi[0] : 0;
padding_hi = std::vector<int>(spatial_dims, padding_int);
}
if (kernel_dilation.size() <= 1) {
int kernel_dilation_int = kernel_dilation.size() ? kernel_dilation[0] : 1;
kernel_dilation = std::vector<int>(spatial_dims, kernel_dilation_int);
}
if (input_dilation.size() <= 1) {
int input_dilation_int = input_dilation.size() ? input_dilation[0] : 1;
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
}
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec);
in.shape(),
wt.shape(),
stride,
padding_lo,
padding_hi,
kernel_dilation,
input_dilation);
return array(
out_shape,
in.dtype(),
std::make_unique<Convolution>(
to_stream(s),
padding_vec,
strides_vec,
dilation_vec,
std::vector<int>(2, 1)),
stride,
padding_lo,
kernel_dilation,
input_dilation,
groups,
flip),
{in, wt});
}
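// Illustrative usage sketch (not part of this commit): conv_general with
// asymmetric padding and a flipped kernel (true convolution rather than
// cross-correlation). Inputs are channels-last: the input is (N, H, W, C_in)
// and the weight is (C_out, kH, kW, C_in).
#include "mlx/mlx.h"
inline void conv_general_example() {
  using namespace mlx::core;
  array in = zeros({1, 8, 8, 3});
  array wt = zeros({4, 3, 3, 3});
  array out = conv_general(
      in,
      wt,
      /* stride */ {2, 2},
      /* padding_lo */ {1, 1},
      /* padding_hi */ {0, 0},
      /* kernel_dilation */ {1, 1},
      /* input_dilation */ {1, 1},
      /* groups */ 1,
      /* flip */ true);
  eval(out); // shape (1, 4, 4, 4)
}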
@ -3388,6 +3489,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
return a;
}
std::vector<array> atleast_1d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_1d(a, s));
}
return out;
}
array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
@ -3399,6 +3511,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
}
}
std::vector<array> atleast_2d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_2d(a, s));
}
return out;
}
array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
@ -3411,4 +3534,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
return a;
}
}
std::vector<array> atleast_3d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_3d(a, s));
}
return out;
}
} // namespace mlx::core
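// Illustrative usage sketch (not part of this commit): the new vector
// overloads promote every array in one call.
#include <vector>
#include "mlx/mlx.h"
inline void atleast_example() {
  using namespace mlx::core;
  std::vector<array> ins = {array(1.0f), zeros({3})};
  std::vector<array> outs = atleast_2d(ins); // shapes (1, 1) and (1, 3)
}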

View File

@ -1026,6 +1026,43 @@ array cummin(
/** Convolution operations */
/** General convolution with a filter */
array conv_general(
array input,
array weight,
std::vector<int> stride = {},
std::vector<int> padding_lo = {},
std::vector<int> padding_hi = {},
std::vector<int> kernel_dilation = {},
std::vector<int> input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {});
/** General convolution with a filter */
inline array conv_general(
const array& input,
const array& weight,
std::vector<int> stride = {},
std::vector<int> padding = {},
std::vector<int> kernel_dilation = {},
std::vector<int> input_dilation = {},
int groups = 1,
bool flip = false,
StreamOrDevice s = {}) {
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ stride,
/* std::vector<int> padding_lo = */ padding,
/* std::vector<int> padding_hi = */ padding,
/* std::vector<int> kernel_dilation = */ kernel_dilation,
/* std::vector<int> input_dilation = */ input_dilation,
/* int groups = */ groups,
/* bool flip = */ flip,
/* StreamOrDevice s = */ s);
}
/** 1D convolution with a filter */
array conv1d(
const array& input,
@ -1123,7 +1160,16 @@ std::vector<array> depends(
/** convert an array to an atleast ndim array */
array atleast_1d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_1d(
const std::vector<array>& a,
StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_2d(
const std::vector<array>& a,
StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_3d(
const std::vector<array>& a,
StreamOrDevice s = {});
} // namespace mlx::core

View File

@ -48,6 +48,54 @@ std::tuple<array, array, int> vmap_binary_op(
return {a, b, to_ax};
}
std::tuple<array, array, array, int> vmap_ternary_op(
const std::vector<array>& inputs,
const std::vector<int>& axes,
const Stream& stream) {
assert(inputs.size() == 3);
assert(axes.size() == 3);
auto a = inputs[0];
auto b = inputs[1];
auto c = inputs[2];
int ndim = std::max(
{a.ndim() + (axes[0] == -1),
b.ndim() + (axes[1] == -1),
c.ndim() + (axes[2] == -1)});
auto expand_dims = [stream, ndim](auto in) {
auto shape = in.shape();
shape.insert(shape.begin(), ndim - shape.size(), 1);
return reshape(in, shape, stream);
};
int to_ax = (ndim - a.ndim()) + axes[0];
int from_ax1 = (ndim - b.ndim()) + axes[1];
int from_ax2 = (ndim - c.ndim()) + axes[2];
a = expand_dims(a);
b = expand_dims(b);
c = expand_dims(c);
auto find_tdims = [](auto x, int to_ax, int from_ax) {
std::vector<int> tdims(x.ndim());
std::iota(tdims.begin(), tdims.end(), 0);
tdims.erase(tdims.begin() + from_ax);
tdims.insert(tdims.begin() + to_ax, from_ax);
return tdims;
};
if (to_ax != from_ax1) {
std::vector<int> tdims = find_tdims(b, to_ax, from_ax1);
b = transpose(b, tdims, stream);
}
if (to_ax != from_ax2) {
std::vector<int> tdims = find_tdims(c, to_ax, from_ax2);
c = transpose(c, tdims, stream);
}
return {a, b, c, to_ax};
}
} // namespace
std::vector<array> Primitive::jvp(
@ -631,21 +679,13 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
return axis_ == c_other.axis_;
}
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> grads;
// Collect info
auto& in = primals[0];
auto& wt = primals[1];
auto cotan = cotangents[0];
int O = wt.shape(0);
array conv_weight_backward_patches(
const array& in,
const array& wt,
const array& cotan,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
StreamOrDevice s) {
// Resolve Padded input shapes and strides
std::vector<int> padding_starts(in.ndim(), 0);
std::vector<int> padding_ends = in.shape();
@ -653,9 +693,9 @@ std::vector<array> Convolution::vjp(
// padded shape
for (int i = 1; i < in.ndim() - 1; i++) {
in_padded_shape[i] += 2 * padding_[i - 1];
padding_ends[i] += padding_[i - 1];
padding_starts[i] += padding_[i - 1];
in_padded_shape[i] += 2 * padding[i - 1];
padding_ends[i] += padding[i - 1];
padding_starts[i] += padding[i - 1];
}
// padded strides (contiguous)
@ -664,6 +704,12 @@ std::vector<array> Convolution::vjp(
in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];
}
// Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
auto in_padded =
pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);
// Resolve strided patches
// patches are shaped as
@ -678,62 +724,108 @@ std::vector<array> Convolution::vjp(
std::vector<size_t> patches_strides(patches_shape.size(), 1);
patches_strides[0] = in_padded_strides[0];
for (int i = 1; i < n_spatial_dim + 1; i++) {
patches_strides[i] = in_padded_strides[i] * kernel_strides_[i - 1];
patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
}
for (int i = 1; i < in.ndim(); i++) {
patches_strides[n_spatial_dim + i] = in_padded_strides[i];
}
// Reshape cotangents and weights for gemm
cotan = reshape(cotangents[0], {-1, O}, stream());
auto weight_reshaped = reshape(wt, {O, -1}, stream());
// Make patches from in
auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, s);
// Prepare for matmul
int O = wt.shape(0);
auto cotan_mat = reshape(cotan, {-1, O}, s);
in_patches = reshape(in_patches, {cotan_mat.shape(0), -1}, s);
auto grad = matmul(transpose(cotan_mat, {1, 0}, s), in_patches, s);
grad = reshape(grad, wt.shape(), s);
return grad;
}
std::vector<array> Convolution::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 2);
std::vector<array> grads;
// Collect info
auto& in = primals[0];
auto& wt = primals[1];
auto& cotan = cotangents[0];
for (int a : argnums) {
// Grads for input
if (a == 0) {
// Gemm with cotangents to get patches
auto grad_patches = matmul(cotan, weight_reshaped, stream());
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
// Prepare base grad array to accumulate on
int in_padded_size = in_padded_strides[0] * in_padded_shape[0];
auto grad = zeros(
{
in_padded_size,
},
in.dtype(),
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_[i] - 1;
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding_[i];
}
auto wt_trans = swapaxes(wt, 0, -1, stream());
auto grad = conv_general(
/* const array& input = */ cotan,
/* const array& weight = */ wt_trans,
/* std::vector<int> stride = */ input_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_dilation_,
/* std::vector<int> input_dilation = */ kernel_strides_,
/* int groups = */ 1,
/* bool flip = */ !flip_,
stream());
// Create index map
int patches_size = grad_patches.size();
auto idx = arange(in_padded_size, stream());
idx = as_strided(idx, patches_shape, patches_strides, 0, stream());
idx = reshape(idx, {patches_size}, stream());
// Flatten patches and scatter
auto flat_patches = reshape(grad_patches, {patches_size, 1}, stream());
grad = scatter_add(grad, idx, flat_patches, 0, stream());
// Reshape and slice away padding
grad = reshape(grad, in_padded_shape, stream());
grad = slice(grad, padding_starts, padding_ends, stream());
grads.push_back(grad);
}
// Grads for weight
else if (a == 1) {
// Make patches from in
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
auto in_padded = pad(
in, padded_axes, padding_, padding_, array(0, in.dtype()), stream());
auto in_patches =
as_strided(in_padded, patches_shape, patches_strides, 0, stream());
in_patches = reshape(in_patches, {cotan.shape(0), -1}, stream());
bool no_dilation = true;
auto grad =
matmul(transpose(cotan, {1, 0}, stream()), in_patches, stream());
grad = reshape(grad, wt.shape(), stream());
grads.push_back(grad);
for (int i = 0; i < input_dilation_.size(); i++) {
no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
}
if (no_dilation) {
auto grad = conv_weight_backward_patches(
in, wt, cotan, kernel_strides_, padding_, stream());
grads.push_back(grad);
} else {
std::vector<int> padding_lo = padding_;
std::vector<int> padding_hi = padding_;
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
}
auto in_trans = swapaxes(in, 0, -1, stream());
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto grad_trans = conv_general(
/* const array& input = */ in_trans,
/* const array& weight = */ cotan_trans,
/* std::vector<int> stride = */ kernel_dilation_,
/* std::vector<int> padding_lo = */ padding_lo,
/* std::vector<int> padding_hi = */ padding_hi,
/* std::vector<int> kernel_dilation = */ kernel_strides_,
/* std::vector<int> input_dilation = */ input_dilation_,
/* int groups = */ 1,
/* bool flip = */ flip_,
stream());
auto grad = swapaxes(grad_trans, 0, -1, stream());
grads.push_back(grad);
}
}
}
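For a quick sanity check of the reworked convolution gradients, a hedged sketch (shapes are illustrative; only ``mx.conv1d`` and ``mx.grad`` are assumed):
import mlx.core as mx
# Input grads go through a transposed conv_general; weight grads use the
# patches path when there is no input/kernel dilation
x = mx.random.normal((2, 10, 3))           # (N, L, C_in)
w = mx.random.normal((4, 3, 3))            # (C_out, K, C_in)
loss = lambda x, w: mx.conv1d(x, w, stride=2, padding=1).sum()
dx, dw = mx.grad(loss, argnums=(0, 1))(x, w)
print(dx.shape, dw.shape)                  # (2, 10, 3) (4, 3, 3)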
@ -745,7 +837,8 @@ bool Convolution::is_equivalent(const Primitive& other) const {
return padding_ == c_other.padding_ &&
kernel_strides_ == c_other.kernel_strides_ &&
kernel_dilation_ == c_other.kernel_dilation_ &&
input_dilation_ == c_other.input_dilation_;
input_dilation_ == c_other.input_dilation_ &&
groups_ == c_other.groups_ && flip_ == c_other.flip_;
}
std::vector<array> Copy::vjp(
@ -1775,6 +1868,76 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
return {{multiply(a, b, stream())}, {to_ax}};
}
std::vector<array> Select::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 3);
assert(tangents.size() == 3);
auto jvp_fun = [&](int i) {
int arg = argnums[i];
if (arg == 0) {
return zeros_like(primals[0], stream());
} else if (arg == 1) {
return multiply(
astype(primals[0], tangents[1].dtype(), stream()),
tangents[1],
stream());
} else {
return multiply(
astype(
logical_not(primals[0], stream()), tangents[2].dtype(), stream()),
tangents[2],
stream());
}
};
array jvp = jvp_fun(argnums[0]);
for (int i = 1; i < argnums.size(); i++) {
jvp = add(jvp, jvp_fun(argnums[i]));
}
return {jvp};
}
std::vector<array> Select::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(primals.size() == 3);
assert(cotangents.size() == 1);
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg == 0) {
vjps.push_back(zeros_like(primals[0], stream()));
} else if (arg == 1) {
vjps.push_back(multiply(
astype(primals[0], cotangents[0].dtype(), stream()),
cotangents[0],
stream()));
} else if (arg == 2) {
vjps.push_back(multiply(
astype(
logical_not(primals[0], stream()),
cotangents[0].dtype(),
stream()),
cotangents[0],
stream()));
}
}
return vjps;
}
std::pair<std::vector<array>, std::vector<int>> Select::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, c, to_ax] = vmap_ternary_op(inputs, axes, stream());
return {{where(a, b, c, stream())}, {to_ax}};
}
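In Python terms the rules above mean gradients only flow through the selected branch of ``where``; a hedged sketch:
import mlx.core as mx
f = lambda x, y: mx.where(x > 0, x * y, y * y).sum()
dx, dy = mx.grad(f, argnums=(0, 1))(mx.array([1.0, -1.0]), mx.array([3.0, 3.0]))
# dx == [3, 0] and dy == [1, 6], matching the vjp cases for arg 1 and arg 2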
std::vector<array> Negative::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@ -1925,7 +2088,10 @@ std::vector<array> Power::vjp(
primals[1],
stream()));
} else {
vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
auto& exp = outputs[0];
auto exp_vjp = multiply(log(primals[0], stream()), outputs[0], stream());
// 0 * log 0 -> 0
vjps.push_back(where(exp, exp_vjp, array(0.0f, exp.dtype()), stream()));
}
vjps.back() = multiply(cotangents[0], vjps.back(), stream());
}
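The guard only matters when the output itself is zero; a hedged example of the convention it implements:
import mlx.core as mx
# d/dy 0**y is 0**y * log(0); without the guard this is 0 * (-inf) = nan.
# With the 0 * log 0 -> 0 convention the gradient is simply 0 for y > 0.
g = mx.grad(lambda y: mx.array(0.0) ** y)(mx.array(2.0))
print(g)   # 0.0 rather than nan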

View File

@ -544,15 +544,19 @@ class Convolution : public UnaryPrimitive {
public:
explicit Convolution(
Stream stream,
const std::vector<int>& padding,
const std::vector<int>& kernel_strides,
const std::vector<int>& padding,
const std::vector<int>& kernel_dilation,
const std::vector<int>& input_dilation)
const std::vector<int>& input_dilation,
const int groups = 1,
const bool flip = false)
: UnaryPrimitive(stream),
padding_(padding),
kernel_strides_(kernel_strides),
kernel_dilation_(kernel_dilation),
input_dilation_(input_dilation){};
input_dilation_(input_dilation),
groups_(groups),
flip_(flip){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -571,6 +575,8 @@ class Convolution : public UnaryPrimitive {
std::vector<int> kernel_strides_;
std::vector<int> kernel_dilation_;
std::vector<int> input_dilation_;
int groups_;
bool flip_;
void eval(const std::vector<array>& inputs, array& out);
};
@ -719,6 +725,23 @@ class DivMod : public Primitive {
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
class Select : public UnaryPrimitive {
public:
explicit Select(Stream stream) : UnaryPrimitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Select)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Remainder : public UnaryPrimitive {
public:
explicit Remainder(Stream stream) : UnaryPrimitive(stream){};

View File

@ -37,6 +37,7 @@ from mlx.nn.layers.activations import (
relu,
relu6,
selu,
sigmoid,
silu,
softmax,
softplus,
@ -67,3 +68,4 @@ from mlx.nn.layers.transformer import (
TransformerEncoder,
TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample

View File

@ -18,7 +18,7 @@ def _make_activation_module(f):
@partial(mx.compile, shapeless=True)
def sigmoid(x):
r"""Applies the element-wise function:
r"""Applies the sigmoid function.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
@ -142,11 +142,11 @@ def log_sigmoid(x):
@partial(mx.compile, shapeless=True)
def gelu(x):
def gelu(x) -> mx.array:
r"""Applies the Gaussian Error Linear Units function.
.. math::
\\textrm{GELU}(x) = x * \Phi(x)
\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
@ -185,11 +185,15 @@ def gelu_fast_approx(x):
.. math::
x = x \sigma\left(1.773 x\right)
x = x \sigma\left(1.702 x\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
References:
- https://github.com/hendrycks/GELUs
- https://arxiv.org/abs/1606.08415
"""
return x * mx.sigmoid(1.773 * x)
return x * mx.sigmoid(1.702 * x)
def glu(x: mx.array, axis: int = -1) -> mx.array:
@ -199,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
\textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``
@ -260,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
Reference: https://arxiv.org/abs/1908.08681
@ -297,7 +302,7 @@ class GLU(Module):
(:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.
.. math::
textrm{GLU}(x) = a * \sigma(b)
\textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``

View File

@ -7,6 +7,42 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)
elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}
elif isinstance(value, dict):
nd = {}
for k, v in value.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd
elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
class Module(dict):
"""Base class for building neural networks with MLX.
@ -98,10 +134,13 @@ class Module(dict):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
super(Module, self).__getattribute__(key)
def __setattr__(self, key: str, val: Any):
self[key] = val
if isinstance(val, (mx.array, dict, list, tuple)):
self[key] = val
else:
super(Module, self).__setattr__(key, val)
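A hedged sketch of what this change means for users (the ``Scale`` module is hypothetical):
import mlx.core as mx
import mlx.nn as nn
class Scale(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = mx.ones((4,))   # array -> stored in the module dict
        self.name = "scale"      # plain str -> regular Python attribute now
m = Scale()
print("w" in m, "name" in m)     # True False
print(m.name)                    # still accessible as an attribute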
def load_weights(
self,
@ -245,31 +284,11 @@ class Module(dict):
is_leaf_fn = is_leaf_fn or (
lambda m, k, v: not isinstance(v, (Module, dict, list))
)
def unwrap(vk, v):
if is_leaf_fn(self, vk, v):
return map_fn(v)
if isinstance(v, Module):
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
if isinstance(v, dict):
nd = {}
for k, v in v.items():
tk = f"{vk}.{k}"
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
return nd
if isinstance(v, list):
nl = []
for i, vi in enumerate(v):
tk = f"{vk}.{i}"
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
return nl
raise RuntimeError("Unexpected leaf found while traversing the module")
return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
return {
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in self.items()
if filter_fn(self, k, v)
}
def parameters(self):
"""Recursively return all the :class:`mlx.core.array` members of this Module

View File

@ -0,0 +1,205 @@
# Copyright © 2023-2024 Apple Inc.
import operator
from functools import reduce
from itertools import product
from typing import Literal, Tuple, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
def _scaled_indices(N, scale, align_corners, dim, ndims):
M = int(scale * N)
if align_corners:
indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
else:
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2
indices = mx.arange(M, dtype=mx.float32) * step - start
indices = mx.clip(indices, 0, N - 1)
shape = [1] * ndims
shape[dim] = -1
return indices.reshape(shape)
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
def _linear_indices(N, scale, align_corners, dim, ndims):
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
indices_l = mx.floor(indices)
indices_r = mx.ceil(indices)
weight = indices - indices_l
weight = mx.expand_dims(weight, -1)
return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
)
def upsample_nearest(x: mx.array, scale_factor: Tuple):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
# Integer scale_factors means we can simply expand-broadcast and reshape
if tuple(map(int, scale_factor)) == scale_factor:
shape = list(x.shape)
for d in range(dims):
shape.insert(2 + 2 * d, 1)
x = x.reshape(shape)
for d in range(dims):
shape[2 + 2 * d] = int(scale_factor[d])
x = mx.broadcast_to(x, shape)
for d in range(dims):
shape[d + 1] *= shape[d + 2]
shape.pop(d + 2)
x = x.reshape(shape)
return x
else:
B, *N, C = x.shape
indices = [slice(None)]
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_nearest_indices(n, s, i, dims))
indices = tuple(indices)
return x[indices]
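The integer-scale fast path above is just reshape/broadcast/reshape; a hedged sketch of the equivalence for a 2x nearest upsample:
import mlx.core as mx
x = mx.arange(4).reshape((1, 2, 2, 1))                      # (B, H, W, C)
y = mx.broadcast_to(x.reshape((1, 2, 1, 2, 1, 1)), (1, 2, 2, 2, 2, 1))
y = y.reshape((1, 4, 4, 1))                                 # each pixel -> 2x2 block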
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
B, *N, C = x.shape
# Compute the sampling grid
indices = []
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_linear_indices(n, s, align_corners, i, dims))
# Sample and compute the weights
samples = []
weights = []
for idx_weight in product(*indices):
idx, weight = zip(*idx_weight)
samples.append(x[(slice(None),) + idx])
weights.append(reduce(operator.mul, weight))
# Interpolate
return sum(wi * xi for wi, xi in zip(weights, samples))
class Upsample(Module):
r"""Upsample the input signal spatially.
The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
2``. The first is the batch dimension and the last is the feature
dimension.
For example, an audio signal would be 3D with 1 spatial dimension, an image
4D with 2, and so on.
There are two upsampling algorithms implemented: nearest neighbor upsampling
and linear interpolation. Both can be applied to any number of spatial
dimensions. The linear interpolation will be bilinear, trilinear, etc. when
applied to more than one spatial dimension.
.. note::
When using one of the linear interpolation modes the ``align_corners``
argument changes how the corners are treated in the input image. If
``align_corners=True`` then the top and left edges of the input and
output will match, as will the bottom and right edges.
Parameters:
scale_factor (float or tuple): The multiplier for the spatial size.
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
Otherwise, the number of scale factors provided must match the
number of spatial dimensions.
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
``"linear"``. Default: ``"nearest"``.
align_corners (bool, optional): Changes the way the corners are treated
during ``"linear"`` upsampling. See the note above and the
examples below for more details. Default: ``False``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
>>> x
array([[[[1],
[2]],
[[3],
[4]]]], dtype=int32)
>>> n = nn.Upsample(scale_factor=2, mode='nearest')
>>> n(x).squeeze()
array([[1, 1, 2, 2],
[1, 1, 2, 2],
[3, 3, 4, 4],
[3, 3, 4, 4]], dtype=int32)
>>> b = nn.Upsample(scale_factor=2, mode='linear')
>>> b(x).squeeze()
array([[1, 1.25, 1.75, 2],
[1.5, 1.75, 2.25, 2.5],
[2.5, 2.75, 3.25, 3.5],
[3, 3.25, 3.75, 4]], dtype=float32)
>>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
>>> b(x).squeeze()
array([[1, 1.33333, 1.66667, 2],
[1.66667, 2, 2.33333, 2.66667],
[2.33333, 2.66667, 3, 3.33333],
[3, 3.33333, 3.66667, 4]], dtype=float32)
"""
def __init__(
self,
scale_factor: Union[float, Tuple],
mode: Literal["nearest", "linear"] = "nearest",
align_corners: bool = False,
):
super().__init__()
if mode not in ["nearest", "linear"]:
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
if isinstance(scale_factor, (list, tuple)):
self.scale_factor = tuple(map(float, scale_factor))
else:
self.scale_factor = float(scale_factor)
self.mode = mode
self.align_corners = align_corners
def _extra_repr(self) -> str:
return (
f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
f"align_corners={self.align_corners}"
)
def __call__(self, x: mx.array) -> mx.array:
dims = x.ndim - 2
if dims <= 0:
raise ValueError(
f"[Upsample] The input should have at least 1 spatial "
f"dimension which means it should be at least 3D but "
f"{x.ndim}D was provided"
)
scale_factor = self.scale_factor
if isinstance(scale_factor, tuple):
if len(scale_factor) != dims:
raise ValueError(
f"[Upsample] One scale per spatial dimension is required but "
f"scale_factor={scale_factor} and the number of spatial "
f"dimensions were {dims}"
)
else:
scale_factor = (scale_factor,) * dims
if self.mode == "nearest":
return upsample_nearest(x, scale_factor)
else:
return upsample_linear(x, scale_factor, self.align_corners)

View File

@ -1,11 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Callable, List
import mlx.core as mx
def exponential_decay(init: float, decay_rate: float):
def exponential_decay(init: float, decay_rate: float) -> Callable:
r"""Make an exponential decay scheduler.
Args:
@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
return schedule
def step_decay(init: float, decay_rate: float, step_size: int):
def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
r"""Make a step decay scheduler.
Args:
@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
return schedule
def cosine_decay(init: float, decay_steps: int):
def cosine_decay(init: float, decay_steps: int) -> Callable:
r"""Make a cosine decay scheduler.
Args:
@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
return init * decay
return scheduler
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
r"""Join multiple schedules to create a new schedule.
Args:
schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
receives a step count indicating the number of steps since
the :math:`i`-th boundary.
boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
that indicates when to transition between schedules.
Example:
>>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
>>> cosine = optim.cosine_decay(1e-1, 200)
>>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(12): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.0999938, dtype=float32)
"""
if len(schedules) == 0:
raise ValueError("Must provide at least 1 schedule to join.")
if len(schedules) != len(boundaries) + 1:
raise ValueError(
f"Received {len(boundaries)} boundaries but "
f"expected {len(schedules) - 1}."
)
def schedule(step):
output = schedules[0](step)
for boundary, schedule in zip(boundaries, schedules[1:]):
output = mx.where(step < boundary, output, schedule(step - boundary))
return output
return schedule
def linear_schedule(init: float, end: float, steps: int) -> Callable:
r"""Make a linear scheduler.
Args:
init (float): Initial value.
end (float): Final value.
steps (int): Number of steps to apply the schedule over. The value is
``end`` for any steps beyond ``steps``.
Example:
>>> warmup = optim.linear_schedule(0, 1e-1, 100)
>>> optimizer = optim.Adam(learning_rate=warmup)
>>> optimizer.learning_rate
array(0.0, dtype=float32)
>>> for _ in range(101): optimizer.update({}, {})
...
>>> optimizer.learning_rate
array(0.1, dtype=float32)
"""
if steps < 1:
raise ValueError(f"steps must be greater than 0, but got {steps}.")
def step_fn(step):
step = mx.minimum(step, steps)
return step * ((end - init) / steps) + init
return step_fn
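Since a schedule is just a callable on the step count, it can be probed directly; a hedged sketch combining the pieces above:
import mlx.optimizers as optim
sched = optim.join_schedules(
    [optim.linear_schedule(0.0, 1e-1, 10), optim.cosine_decay(1e-1, 100)], [10]
)
print(sched(0), sched(10), sched(110))   # 0.0, 0.1, then decayed toward 0.0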

View File

@ -14,6 +14,7 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <cstring>
@ -7,6 +7,7 @@
#include <pybind11/numpy.h>
#include "python/src/indexing.h"
#include "python/src/pybind11_numpy_fp16.h"
#include "python/src/utils.h"
#include "mlx/ops.h"
@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
shape.push_back(np_array.shape(i));
}
// Get dtype
auto type = np_array.dtype();
// Copy data and make array
if (type.is(py::dtype::of<int>())) {
if (py::isinstance<py::array_t<int32_t>>(np_array)) {
return np_array_to_mlx_contiguous<int32_t>(
np_array, shape, dtype.value_or(int32));
} else if (type.is(py::dtype::of<uint32_t>())) {
} else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint32_t>(
np_array, shape, dtype.value_or(uint32));
} else if (type.is(py::dtype::of<bool>())) {
} else if (py::isinstance<py::array_t<bool>>(np_array)) {
return np_array_to_mlx_contiguous<bool>(
np_array, shape, dtype.value_or(bool_));
} else if (type.is(py::dtype::of<double>())) {
} else if (py::isinstance<py::array_t<double>>(np_array)) {
return np_array_to_mlx_contiguous<double>(
np_array, shape, dtype.value_or(float32));
} else if (type.is(py::dtype::of<float>())) {
} else if (py::isinstance<py::array_t<float>>(np_array)) {
return np_array_to_mlx_contiguous<float>(
np_array, shape, dtype.value_or(float32));
} else if (type.is(py::dtype("float16"))) {
} else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
return np_array_to_mlx_contiguous<float>(
np_array, shape, dtype.value_or(float16));
} else if (type.is(py::dtype::of<uint8_t>())) {
} else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint8_t>(
np_array, shape, dtype.value_or(uint8));
} else if (type.is(py::dtype::of<uint16_t>())) {
} else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint16_t>(
np_array, shape, dtype.value_or(uint16));
} else if (type.is(py::dtype::of<uint64_t>())) {
} else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
return np_array_to_mlx_contiguous<uint64_t>(
np_array, shape, dtype.value_or(uint64));
} else if (type.is(py::dtype::of<int8_t>())) {
} else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
return np_array_to_mlx_contiguous<int8_t>(
np_array, shape, dtype.value_or(int8));
} else if (type.is(py::dtype::of<int16_t>())) {
} else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
return np_array_to_mlx_contiguous<int16_t>(
np_array, shape, dtype.value_or(int16));
} else if (type.is(py::dtype::of<int64_t>())) {
} else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
return np_array_to_mlx_contiguous<int64_t>(
np_array, shape, dtype.value_or(int64));
} else if (type.is(py::dtype::of<std::complex<float>>())) {
} else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
return np_array_to_mlx_contiguous<std::complex<float>>(
np_array, shape, dtype.value_or(complex64));
} else if (type.is(py::dtype::of<std::complex<double>>())) {
} else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
return np_array_to_mlx_contiguous<std::complex<float>>(
np_array, shape, dtype.value_or(complex64));
} else {
std::ostringstream msg;
msg << "Cannot convert numpy array of type " << type << " to mlx array.";
msg << "Cannot convert numpy array of type " << np_array.dtype()
<< " to mlx array.";
throw std::invalid_argument(msg.str());
}
}

View File

@ -5,18 +5,88 @@
#include "mlx/backend/metal/metal.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
void init_metal(py::module_& m) {
py::module_ metal = m.def_submodule("metal", "mlx.metal");
metal.def("is_available", &metal::is_available);
metal.def(
"cache_enabled",
&metal::cache_enabled,
"check if metal buffer cache is enabled, default is true");
"is_available",
&metal::is_available,
R"pbdoc(
Check if the Metal back-end is available.
)pbdoc");
metal.def(
"set_cache_enabled",
&metal::set_cache_enabled,
"enable or disable metal buffer cache");
"get_active_memory",
&metal::get_active_memory,
R"pbdoc(
Get the actively used memory in bytes.
Note, this will not always match memory use reported by the system because
it does not include cached memory buffers.
)pbdoc");
metal.def(
"get_peak_memory",
&metal::get_peak_memory,
R"pbdoc(
Get the peak amount of used memory in bytes.
The maximum memory used is recorded from the beginning of the program
execution.
)pbdoc");
metal.def(
"get_cache_memory",
&metal::get_cache_memory,
R"pbdoc(
Get the cache size in bytes.
The cache includes memory not currently used that has not been returned
to the system allocator.
)pbdoc");
metal.def(
"set_memory_limit",
&metal::set_memory_limit,
"limit"_a,
py::kw_only(),
"relaxed"_a = true,
R"pbdoc(
Set the memory limit.
Memory allocations will wait on scheduled tasks to complete if the limit
is exceeded. If there are no more scheduled tasks an error will be raised
if ``relaxed`` is ``False``. Otherwise memory will be allocated
(including the potential for swap) if ``relaxed`` is ``True``.
The memory limit defaults to 1.5 times the maximum recommended working set
size reported by the device.
Args:
limit (int): Memory limit in bytes.
relaxed (bool, optional): If ``False``, an error is raised if the limit
is exceeded. Default: ``True``.
Returns:
int: The previous memory limit in bytes.
)pbdoc");
metal.def(
"set_cache_limit",
&metal::set_cache_limit,
"limit"_a,
R"pbdoc(
Set the free cache limit.
If using more than the given limit, free memory will be reclaimed
from the cache on the next allocation. To disable the cache, set
the limit to ``0``.
The cache limit defaults to the memory limit. See
:func:`set_memory_limit` for more details.
Args:
limit (int): The cache limit in bytes.
Returns:
int: The previous cache limit in bytes.
)pbdoc");
}
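A hedged usage sketch of the memory APIs bound above:
import mlx.core as mx
if mx.metal.is_available():
    prev = mx.metal.set_cache_limit(0)        # disable the buffer cache
    a = mx.zeros((1024, 1024))
    mx.eval(a)
    print(mx.metal.get_active_memory(), mx.metal.get_peak_memory())
    mx.metal.set_cache_limit(prev)            # restore the previous limit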

View File

@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array
2D convolution over an input with several channels
@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
array: The convolved array.
)pbdoc");
m.def(
"conv_general",
[](const array& input,
const array& weight,
const std::variant<int, std::vector<int>>& stride,
const std::variant<
int,
std::vector<int>,
std::pair<std::vector<int>, std::vector<int>>>& padding,
const std::variant<int, std::vector<int>>& kernel_dilation,
const std::variant<int, std::vector<int>>& input_dilation,
int groups,
bool flip,
StreamOrDevice s) {
std::vector<int> stride_vec;
std::vector<int> padding_lo_vec;
std::vector<int> padding_hi_vec;
std::vector<int> kernel_dilation_vec;
std::vector<int> input_dilation_vec;
if (auto pv = std::get_if<int>(&stride); pv) {
stride_vec.push_back(*pv);
} else {
stride_vec = std::get<std::vector<int>>(stride);
}
if (auto pv = std::get_if<int>(&padding); pv) {
padding_lo_vec.push_back(*pv);
padding_hi_vec.push_back(*pv);
} else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
padding_lo_vec = *pv;
padding_hi_vec = *pv;
} else {
auto [pl, ph] =
std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
padding_lo_vec = pl;
padding_hi_vec = ph;
}
if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
kernel_dilation_vec.push_back(*pv);
} else {
kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
}
if (auto pv = std::get_if<int>(&input_dilation); pv) {
input_dilation_vec.push_back(*pv);
} else {
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
}
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ stride_vec,
/* std::vector<int> padding_lo = */ padding_lo_vec,
/* std::vector<int> padding_hi = */ padding_hi_vec,
/* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
/* std::vector<int> input_dilation = */ input_dilation_vec,
/* int groups = */ groups,
/* bool flip = */ flip,
s);
},
"input"_a,
"weight"_a,
py::pos_only(),
"stride"_a = 1,
"padding"_a = 0,
"kernel_dilation"_a = 1,
"input_dilation"_a = 1,
"groups"_a = 1,
"flip"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
General convolution over an input with several channels
.. note::
* Only 1d and 2d convolutions are supported at the moment
* Only the default ``groups=1`` is currently supported.
Args:
input (array): Input array of shape ``(N, ..., C_in)``
weight (array): Weight array of shape ``(C_out, ..., C_in)``
stride (int or list(int), optional): :obj:`list` with kernel strides.
All spatial dimensions get the same stride if
only one number is specified. Default: ``1``.
padding (int, list(int), or tuple(list(int), list(int)), optional):
:obj:`list` with input padding. All spatial dimensions get the same
padding if only one number is specified. Default: ``0``.
kernel_dilation (int or list(int), optional): :obj:`list` with
kernel dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
input_dilation (int or list(int), optional): :obj:`list` with
input dilation. All spatial dimensions get the same dilation
if only one number is specified. Default: ``1``
groups (int, optional): Input feature groups. Default: ``1``.
flip (bool, optional): Flip the order in which the spatial dimensions of
the weights are processed. Performs the cross-correlation operator when
``flip`` is ``False`` and the convolution operator otherwise.
Default: ``False``.
Returns:
array: The convolved array.
)pbdoc");
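A hedged 1D usage sketch of the binding above (NLC layout, shapes illustrative):
import mlx.core as mx
x = mx.random.normal((1, 16, 4))    # (N, L, C_in)
w = mx.random.normal((8, 3, 4))     # (C_out, K, C_in)
y = mx.conv_general(x, w, stride=2, padding=1)
print(y.shape)                       # (1, 8, 8)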
m.def(
"save",
&mlx_save_helper,
"file"_a,
@ -3638,62 +3746,69 @@ void init_ops(py::module_& m) {
)pbdoc");
m.def(
"atleast_1d",
&atleast_1d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_1d(arys[0].cast<array>(), s));
}
return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least one dimension.
Convert all arrays to have at least one dimension.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least one dimension.
array or list(array): An array or list of arrays with at least one dimension.
)pbdoc");
m.def(
"atleast_2d",
&atleast_2d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_2d(arys[0].cast<array>(), s));
}
return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least two dimensions.
Convert all arrays to have at least two dimensions.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least two dimensions.
array or list(array): An array or list of arrays with at least two dimensions.
)pbdoc");
m.def(
"atleast_3d",
&atleast_3d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_3d(arys[0].cast<array>(), s));
}
return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least three dimensions.
Convert all arrays to have at least three dimensions.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least three dimensions.
array or list(array): An array or list of arrays with at least three dimensions.
)pbdoc");
}
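A hedged sketch of the new variadic behaviour of the ``atleast_*`` bindings:
import mlx.core as mx
a = mx.atleast_2d(mx.array(1.0))                      # single input -> array
b, c = mx.atleast_2d(mx.array(1.0), mx.zeros((3,)))   # several inputs -> list
print(a.shape, b.shape, c.shape)                      # (1, 1) (1, 1) (1, 3)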

View File

@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679
#include <pybind11/numpy.h>
namespace pybind11::detail {
template <typename T>
struct npy_scalar_caster {
PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
using Array = array_t<T>;
bool load(handle src, bool convert) {
// Taken from Eigen casters. Permits either scalar dtype or scalar array.
handle type = dtype::of<T>().attr("type"); // Could make more efficient.
if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
return false;
Array tmp = Array::ensure(src);
if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
this->value = *tmp.data();
return true;
}
return false;
}
static handle cast(T src, return_value_policy, handle) {
Array tmp({1});
tmp.mutable_at(0) = src;
tmp.resize({});
// You could also just return the array if you want a scalar array.
object scalar = tmp[tuple()];
return scalar.release();
}
};
// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
static constexpr auto name = _("float16");
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
};
template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
static constexpr auto name = _("float16");
};
} // namespace pybind11::detail

View File

@ -11,6 +11,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "python/src/trees.h"
namespace py = pybind11;
using namespace py::literals;
@ -30,246 +31,6 @@ std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
return vals;
}
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
@ -582,9 +343,69 @@ struct PyCompiledFun {
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto inputs = tree_flatten(args, false);
// Flat array inputs
std::vector<array> inputs;
auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
// Compilation constants which includes the tree structure of the arguments
std::vector<uint64_t> constants;
// Reserve some large primes to signify the presence of an array, a list or
// a dict in order to encode the structure of the pytree. We choose primes
// to reduce slightly the chances of these numbers occurring by a
// multiplication as values in the constants list.
constexpr uint64_t array_identifier = 18446744073709551557UL;
constexpr uint64_t list_identifier = 18446744073709551533UL;
constexpr uint64_t dict_identifier = 18446744073709551521UL;
// Flatten the tree with hashed constants and structure
std::function<void(py::handle)> recurse;
recurse = [&](py::handle obj) {
if (py::isinstance<py::list>(obj)) {
auto l = py::cast<py::list>(obj);
constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) {
recurse(l[i]);
}
} else if (py::isinstance<py::tuple>(obj)) {
auto l = py::cast<py::tuple>(obj);
constants.push_back(list_identifier);
for (auto item : obj) {
recurse(item);
}
} else if (py::isinstance<py::dict>(obj)) {
auto d = py::cast<py::dict>(obj);
constants.push_back(dict_identifier);
for (auto item : d) {
auto r = py::hash(item.first);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
recurse(item.second);
}
} else if (py::isinstance<array>(obj)) {
inputs.push_back(py::cast<array>(obj));
constants.push_back(array_identifier);
} else if (py::isinstance<py::str>(obj)) {
auto r = py::hash(obj);
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::int_>(obj)) {
auto r = obj.cast<int64_t>();
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else if (py::isinstance<py::float_>(obj)) {
auto r = obj.cast<double>();
constants.push_back(*reinterpret_cast<uint64_t*>(&r));
} else {
std::ostringstream msg;
msg << "[compile] Function arguments must be trees of arrays "
<< "or constants (floats, ints, or strings), but received "
<< "type " << obj.get_type() << ".";
throw std::invalid_argument(msg.str());
}
};
recurse(args);
int num_args = inputs.size();
recurse(kwargs);
auto compile_fun = [this, &args, &kwargs, num_args](
const std::vector<array>& a) {
// Put tracers into captured inputs
std::vector<array> flat_in_captures;
@ -619,14 +440,6 @@ struct PyCompiledFun {
return outputs;
};
{
auto flat_kwargs = tree_flatten(kwargs, false);
inputs.insert(
inputs.end(),
std::make_move_iterator(flat_kwargs.begin()),
std::make_move_iterator(flat_kwargs.end()));
}
if (!py::isinstance<py::none>(captured_inputs)) {
auto flat_in_captures = tree_flatten(captured_inputs, false);
inputs.insert(
@ -635,36 +448,6 @@ struct PyCompiledFun {
std::make_move_iterator(flat_in_captures.end()));
}
// Collect the compilation constants
std::vector<uint64_t> constants;
auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
// Consider expanding tuples to their contents including start and end
// ids
if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
auto r = py::hash(o);
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::int_>(o)) {
auto r = o.cast<int64_t>();
return *reinterpret_cast<uint64_t*>(&r);
} else if (py::isinstance<py::float_>(o)) {
auto r = o.cast<double>();
return *reinterpret_cast<uint64_t*>(&r);
} else {
return std::nullopt;
}
};
for (int i = 0; i < args.size(); i++) {
if (auto h = value_hash(args[i]); h.has_value()) {
constants.push_back(*h);
}
}
for (auto& pair : kwargs) {
if (auto h = value_hash(pair.second); h.has_value()) {
constants.push_back(*value_hash(pair.first));
constants.push_back(*h);
}
}
// Compile and call
auto outputs =
detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
@ -1017,7 +800,38 @@ void init_transforms(py::module_& m) {
const py::object& inputs,
const py::object& outputs,
bool shapeless) {
return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
py::options options;
options.disable_function_signatures();
std::ostringstream doc;
auto name = fun.attr("__name__").cast<std::string>();
doc << name;
// Try to get the signature
auto inspect = py::module::import("inspect");
if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
doc << inspect.attr("signature")(fun)
.attr("__str__")()
.cast<std::string>();
}
// Try to get the doc string
if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
doc << "\n\n";
auto dstr = d.cast<std::string>();
// Add spaces to match first line indentation with remainder of
// docstring
int i = 0;
for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
doc << ' ';
}
doc << dstr;
}
auto doc_str = doc.str();
return py::cpp_function(
PyCompiledFun{fun, inputs, outputs, shapeless},
py::name(name.c_str()),
py::doc(doc_str.c_str()));
},
"fun"_a,
"inputs"_a = std::nullopt,

python/src/trees.cpp (new file, 243 lines)
View File

@ -0,0 +1,243 @@
// Copyright © 2023-2024 Apple Inc.
#include "python/src/trees.h"
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
void tree_visit_update(
py::object tree,
std::function<py::object(py::handle)> visitor) {
std::function<py::object(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree)) {
auto l = py::cast<py::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
return py::cast<py::object>(subtree);
} else if (py::isinstance<py::dict>(subtree)) {
auto d = py::cast<py::dict>(subtree);
for (auto item : d) {
d[item.first] = recurse(item.second);
}
return py::cast<py::object>(d);
} else if (py::isinstance<array>(subtree)) {
return visitor(subtree);
} else {
return py::cast<py::object>(subtree);
}
};
recurse(tree);
}
// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](py::handle node) { return py::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst) {
std::unordered_map<uintptr_t, array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](py::handle node) {
auto arr = py::cast<array>(node);
if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
return py::cast(it->second);
}
return py::cast(arr);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index /* = 0 */) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict /* = true */) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index /* = 0 */) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}

60
python/src/trees.h Normal file
View File

@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/array.h"
namespace py = pybind11;
using namespace mlx::core;
/** Visit every leaf of a pytree with the given visitor. */
void tree_visit(py::object tree, std::function<void(py::handle)> visitor);
/** Map a function over the leaves of one or more pytrees visited in lockstep. */
py::object tree_map(
    const std::vector<py::object>& trees,
    std::function<py::object(const std::vector<py::object>&)> transform);
/** Map a function over the leaves of a single pytree. */
py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform);
/** Visit the leaves of a pytree, updating list and dict entries in place. */
void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor);
/**
* Fill a pytree (recursive dict or list of dict or list) in place with the
* given arrays. */
void tree_fill(py::object& tree, const std::vector<array>& values);
/**
* Replace all the arrays from the src values with the dst values in the
* tree.
*/
void tree_replace(
py::object& tree,
const std::vector<array>& src,
const std::vector<array>& dst);
/**
* Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array.
*/
std::vector<array> tree_flatten(py::object tree, bool strict = true);
/**
* Unflatten a tree from a vector of arrays.
*/
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0);
/** Flatten a tree into a vector of arrays plus a structure object that
 * records where each array belongs. */
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
    py::object tree,
    bool strict = true);
/** Unflatten a vector of arrays using a structure returned by
 * tree_flatten_with_structure. */
py::object tree_unflatten_from_structure(
    py::object structure,
    const std::vector<array>& values,
    int index = 0);
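The same pytree conventions are available from Python through `mlx.utils`, which provides its own `tree_map`, `tree_flatten`, and `tree_unflatten` helpers. A minimal sketch of a map/flatten/unflatten round trip, assuming the documented `mlx.utils` behavior of flattening to (dotted-key, value) pairs:

import mlx.core as mx
from mlx.utils import tree_flatten, tree_map, tree_unflatten

params = {"layers": [{"w": mx.ones((2, 2)), "b": mx.zeros((2,))}]}

# Map over every leaf, e.g. to cast parameters to half precision.
half = tree_map(lambda x: x.astype(mx.float16), params)

# Flatten to (dotted_key, array) pairs and rebuild the nested structure.
flat = tree_flatten(half)       # e.g. [("layers.0.w", ...), ("layers.0.b", ...)]
restored = tree_unflatten(flat)
print(restored["layers"][0]["w"].dtype)  # mlx.core.float16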

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import operator
import pickle
import unittest
import weakref
from itertools import permutations
@ -1440,6 +1441,15 @@ class TestArray(mlx_tests.MLXTestCase):
b @= a
self.assertTrue(mx.array_equal(a, b))
def test_load_from_pickled_np(self):
a = np.array([1, 2, 3], dtype=np.int32)
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
if __name__ == "__main__":
unittest.main()

View File

@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 8.0)
def test_power_grad(self):
def fun(x, y):
res = x - y
return res**x
grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
self.assertEqual(grad.item(), 1.0)
if __name__ == "__main__":
unittest.main()

View File

@ -539,6 +539,72 @@ class TestCompile(mlx_tests.MLXTestCase):
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
# Test nested constant
@partial(mx.compile)
def fun(x, y):
if y[0][0] == 1:
return x + 1
else:
return x + 2
z = fun(mx.array(1), [[1]])
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), [[0]])
self.assertEqual(z.item(), 3)
@partial(mx.compile)
def fun(x, a, b):
for ai in a:
for bi in b:
x = bi * x + ai
return x
z = fun(mx.array(1), [1, 1], [2])
self.assertEqual(z.item(), 7)
z = fun(mx.array(1), [1], [1, 2])
self.assertEqual(z.item(), 5)
counter = [0]
@partial(mx.compile)
def fun(x, y):
counter[0] += 1
return x + y
z = fun(mx.array(1), 1)
self.assertEqual(z.item(), 2)
z = fun(1, mx.array(1))
self.assertEqual(z.item(), 2)
self.assertEqual(counter[0], 2)
def test_compile_inf(self):
@mx.compile
def fun(x):
return mx.isinf(x + 2)
out = fun(mx.array([0.0]))
self.assertEqual(out.item(), False)
def test_unsupported_input_types(self):
class MyClass:
value = 1
@mx.compile
def fun(x, y):
return x + y.value
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), MyClass())
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), y=MyClass())
if __name__ == "__main__":
unittest.main()

View File

@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import math
import unittest
@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
[in_mx, wt_mx],
[ct_mx],
)
pt_grad_in = F.grad.conv1d_input(
in_pt.shape,
@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
for idim, kdim, stride, padding, dilation in (
((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
):
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
run_conv2D_grad(
N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
)
def __conv_general_test(
self,
in_shape,
wt_shape,
stride=1,
padding=0,
kernel_dilation=1,
input_dilation=1,
groups=1,
flip=False,
np_dtype=np.float32,
atol=1e-5,
):
with self.subTest(
in_shape=in_shape,
wt_shape=wt_shape,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
np_dtype=np_dtype,
):
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv_general(
in_mx,
wt_mx,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
def conv_general_pt(
inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
):
C = inp.size()[1]
ndim = inp.ndim - 2
map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
stride, padding, kernel_dilation, input_dilation = map(
map_ints, (stride, padding, kernel_dilation, input_dilation)
)
torch_convt_list = (
F.conv_transpose1d,
F.conv_transpose2d,
F.conv_transpose3d,
)
torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
conv_f = torch_conv_list[ndim - 1]
convt_f = torch_convt_list[ndim - 1]
if flip:
wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
if not np.all(input_dilation == 1):
ones = torch.ones(
[C]
+ [
1,
]
* (ndim + 1)
).to(inp.dtype)
inp = convt_f(inp, ones, stride=input_dilation, groups=C)
return conv_f(
inp,
wt,
stride=stride,
padding=padding,
dilation=kernel_dilation,
groups=groups,
)
out_pt = conv_general_pt(
in_pt,
wt_pt,
stride=stride,
padding=padding,
kernel_dilation=kernel_dilation,
input_dilation=input_dilation,
groups=groups,
flip=flip,
)
out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
self.assertEqual(out_mx.shape, out_pt.shape)
self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_general(self):
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (1, 1)
padding = (2, 2)
kernel_dilation = (2, 3)
input_dilation = (1, 1)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 2)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 4)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 10, 16)
stride = (2, 3)
padding = (3, 2)
kernel_dilation = (3, 2)
input_dilation = (2, 5)
flip = False
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
in_shape = (2, 32, 32, 16)
wt_shape = (32, 5, 5, 16)
stride = (2, 3)
padding = (0, 0)
kernel_dilation = (3, 1)
input_dilation = (2, 5)
flip = True
self.__conv_general_test(
in_shape,
wt_shape,
stride,
padding,
kernel_dilation,
input_dilation,
flip=flip,
)
if __name__ == "__main__":

View File

@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
test_file = os.path.join(self.test_dir, "test.safetensors")
with self.assertRaises(Exception):
mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
mx.save_safetensors(
"test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
)
res = mx.load("test.safetensors", return_metadata=True)
res = mx.load(test_file, return_metadata=True)
self.assertEqual(len(res), 2)
self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
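In user code the metadata round trip exercised above looks roughly like this (a sketch; the file name is illustrative, and metadata values are kept as strings, matching the exception the test expects for non-string values):

import mlx.core as mx

# Metadata values must be strings.
mx.save_safetensors(
    "weights.safetensors",
    {"w": mx.ones((2, 2))},
    {"format": "mlx", "epoch": "3"},
)
arrays, metadata = mx.load("weights.safetensors", return_metadata=True)
print(metadata["format"])  # "mlx"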

View File

@ -0,0 +1,45 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMetal(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.metal.set_cache_limit(0)
a = mx.zeros((4096,))
mx.eval(a)
del a
self.assertEqual(mx.metal.get_cache_memory(), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)
old_limit = mx.metal.set_memory_limit(10)
self.assertEqual(mx.metal.set_memory_limit(old_limit), 10)
self.assertEqual(mx.metal.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,))
mx.eval(a)
active_mem = mx.metal.get_active_memory()
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,))
mx.eval(b)
del b
new_active_mem = mx.metal.get_active_memory()
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.metal.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
cache_mem = mx.metal.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
if __name__ == "__main__":
unittest.main()
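Outside the test, these hooks are intended for monitoring and tuning memory use around real workloads. A minimal sketch, assuming a machine where Metal is available and using only the mx.metal calls exercised above:

import mlx.core as mx

# A minimal sketch (assumes mx.metal.is_available() is True on this machine).
old_cache_limit = mx.metal.set_cache_limit(512 * 1024 * 1024)  # bytes; 0 disables caching

x = mx.random.uniform(shape=(4096, 4096))
mx.eval(x @ x)

print("active bytes:", mx.metal.get_active_memory())
print("peak bytes:  ", mx.metal.get_peak_memory())
print("cached bytes:", mx.metal.get_cache_memory())

mx.metal.set_cache_limit(old_cache_limit)  # restore the previous limit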

View File

@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import os
import tempfile
@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx.utils import tree_flatten, tree_map
class TestBase(mlx_tests.MLXTestCase):
@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
y_hat1 = nn.gelu_approx(x)
y_hat2 = nn.gelu_fast_approx(x)
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
def test_sin_pe(self):
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
def test_upsample(self):
b, h, w, c = 1, 2, 2, 1
scale_factor = 2
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=False
)
upsample_nearest_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=False
)
# Test single feature map, align corners
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
expected_nearest = mx.array(
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.333333, 0.666667, 1],
[0.666667, 1, 1.33333, 1.66667],
[1.33333, 1.66667, 2, 2.33333],
[2, 2.33333, 2.66667, 3],
]
]
]
).transpose((0, 2, 3, 1))
# Test single feature map, no align corners
x = (
mx.arange(1, b * h * w * c + 1)
.reshape((b, c, h, w))
.transpose((0, 2, 3, 1))
)
expected_bilinear_no_align_corners = mx.array(
[
[
[
[1.0000, 1.2500, 1.7500, 2.0000],
[1.5000, 1.7500, 2.2500, 2.5000],
[2.5000, 2.7500, 3.2500, 3.5000],
[3.0000, 3.2500, 3.7500, 4.0000],
]
]
]
).transpose((0, 2, 3, 1))
expected_nearest_no_align_corners = mx.array(
[[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
).transpose((0, 2, 3, 1))
self.assertTrue(
np.allclose(
upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
)
)
self.assertTrue(
np.allclose(
upsample_bilinear_no_align_corners(x),
expected_bilinear_no_align_corners,
)
)
# Test a more complex batch
b, h, w, c = 2, 3, 3, 2
scale_factor = 2
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
],
[
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
],
],
[
[
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
],
[
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
],
[
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
],
],
[
[
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
],
[
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test different height and width scale_factor
b, h, w, c = 1, 2, 2, 2
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=(2, 3), mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=(2, 3), mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[2, 2, 2, 3, 3, 3],
[2, 2, 2, 3, 3, 3],
],
[
[4, 4, 4, 5, 5, 5],
[4, 4, 4, 5, 5, 5],
[6, 6, 6, 7, 7, 7],
[6, 6, 6, 7, 7, 7],
],
]
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.2, 0.4, 0.6, 0.8, 1],
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
[2, 2.2, 2.4, 2.6, 2.8, 3],
],
[
[4, 4.2, 4.4, 4.6, 4.8, 5],
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
[6, 6.2, 6.4, 6.6, 6.8, 7],
],
]
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test repr
self.assertEqual(
str(nn.Upsample(scale_factor=2)),
"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
)
self.assertEqual(
str(nn.Upsample(scale_factor=(2, 3))),
"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
)
def test_pooling(self):
# Test 1d pooling
x = mx.array(

View File

@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.arange(0, float("inf"), float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, 5)
with self.assertRaises(ValueError):
INT_MAX = 2147483647
a = mx.arange(0, INT_MAX + 1, 1)
a = mx.arange(5)
expected = [0, 1, 2, 3, 4]
@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(a.tolist(), expected)
self.assertEqual(a.dtype, mx.int32)
a = mx.arange(0, 10, 100)
expected = [0]
self.assertListEqual(a.tolist(), expected)
self.assertEqual(a.dtype, mx.int32)
a = mx.arange(10, 0, 1)
expected = []
self.assertListEqual(a.tolist(), expected)
a = mx.arange(10, 0, float("inf"))
expected = []
self.assertListEqual(a.tolist(), expected)
a = mx.arange(0, 10, float("inf"))
expected = [0]
self.assertListEqual(a.tolist(), expected)
a = mx.arange(0, -10, float("-inf"))
expected = [0]
self.assertListEqual(a.tolist(), expected)
def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x)
@ -1563,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):
for axis in (None, 0, 1, 2):
for kth in (-2, 2):
for kth in (-2, 0, 2):
with self.subTest(dtype=dtype, axis=axis, kth=kth):
np.random.seed(0)
np_dtype = getattr(np, dtype)
@ -1579,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
top_k_mx = mx.topk(a_mx, kth, axis=axis)
self.assertTrue(np.all(c_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
if kth >= 0:
d_np = np.take(b_mx, np.arange(kth), axis=axis)
self.assertTrue(np.all(d_np <= c_mx))
top_k_mx = mx.topk(a_mx, kth, axis=axis)
top_k_np = np.take(
np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
)
self.assertTrue(np.all(top_k_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
N = a_mx.shape[axis] if axis is not None else a_mx.size
M = top_k_mx.shape[axis or 0]
self.assertEqual(M, (kth + N) % N)
@unittest.skipIf(
os.getenv("LOW_MEMORY", None) is not None,
@ -1906,12 +1935,16 @@ class TestOps(mlx_tests.MLXTestCase):
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_1d(*mx_arrays)
for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
def test_atleast_2d(self):
def compare_nested_lists(x, y):
@ -1936,12 +1969,16 @@ class TestOps(mlx_tests.MLXTestCase):
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_2d(*mx_arrays)
for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
def test_atleast_3d(self):
def compare_nested_lists(x, y):
@ -1966,12 +2003,16 @@ class TestOps(mlx_tests.MLXTestCase):
[[[[1]], [[2]], [[3]]]],
]
for array in arrays:
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_3d(*mx_arrays)
for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
if __name__ == "__main__":

View File

@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_schedule_joiner(self):
boundaries = [2, 3, 4]
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
with self.assertRaises(ValueError):
opt.schedulers.join_schedules(schedules, boundaries)
boundaries = [2, 4]
schedule = opt.schedulers.join_schedules(schedules, boundaries)
self.assertEqual(schedule(0).item(), 3)
self.assertEqual(schedule(1).item(), 3)
self.assertEqual(schedule(2).item(), 4)
self.assertEqual(schedule(3).item(), 4)
self.assertEqual(schedule(5).item(), 5)
self.assertEqual(schedule(7).item(), 5)
def test_linear_warmup_with_cosine_decay(self):
warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
cos_with_warmup = opt.schedulers.join_schedules(
[warmup_schedule, cosine_schedule], [101]
)
self.assertEqual(cos_with_warmup(0), 0.0)
self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
optimizer = opt.Adam(learning_rate=cos_with_warmup)
for _ in range(100):
optimizer.update({}, {})
self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
for _ in range(100):
optimizer.update({}, {})
expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
def test_compile_with_schedule(self):
lr_schedule = opt.exponential_decay(1e-1, 0.9)
optimizer = opt.SGD(learning_rate=lr_schedule)

View File

@ -152,7 +152,7 @@ if __name__ == "__main__":
setup(
name="mlx",
version=get_version("0.3.0"),
version=get_version("0.5.0"),
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple silicon.",

View File

@ -1075,6 +1075,37 @@ TEST_CASE("test jvp from vjp") {
CHECK(compute_derivs(subtract));
CHECK(compute_derivs(power));
}
// Conditional selection element-wise op
{
auto condition = random::randint(0, 2, {5, 10});
auto x = random::uniform({5, 10});
auto y = random::uniform({5, 10});
eval(condition, x, y);
auto compute_derivs = [&condition, &x, &y](auto fn) {
auto fn_wrap = [&fn](std::vector<array> inputs) {
return std::vector<array>{
fn(inputs[0], inputs[1], inputs[2], default_device())};
};
// Compute vjp and add results
auto vjps = vjp(fn_wrap, {condition, x, y}, {ones(x.shape())}).second;
auto vjp_out = add(add(vjps[0], vjps[1]), vjps[2]);
// Compute jvp
array jvp_out =
jvp(fn_wrap,
{condition, x, y},
{ones(condition.shape()), ones(y.shape()), ones(x.shape())})
.second[0];
array result = array_equal(vjp_out, jvp_out);
return result.item<bool>();
};
CHECK(compute_derivs(where));
}
}
TEST_CASE("test complex gradients") {

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <array>
#include "doctest/doctest.h"
@ -473,41 +473,42 @@ TEST_CASE("test metal validation") {
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}
TEST_CASE("test metal enable/disable cache") {
// Test enable metal cache
TEST_CASE("test metal memory info") {
// Test cache limits
{
metal::set_cache_enabled(true);
CHECK(metal::cache_enabled());
auto& a = metal::allocator();
auto size = 100;
auto buf = a.malloc(size, false);
// Release a
a.free(buf);
// The cached buffer's length should equal the requested size
CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
auto old_limit = metal::set_cache_limit(0);
{
auto a = zeros({4096});
eval(a);
}
CHECK_EQ(metal::get_cache_memory(), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), 0);
CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
}
// Test disable metal cache
// Test memory limits
{
metal::set_cache_enabled(false);
CHECK(!metal::cache_enabled());
auto old_limit = metal::set_memory_limit(10);
CHECK_EQ(metal::set_memory_limit(old_limit), 10);
CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
}
auto& a = metal::allocator();
auto size = 100;
auto buf = a.malloc(size, false);
auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());
unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
// Query active and peak memory
{
auto a = zeros({4096});
eval(a);
auto active_mem = metal::get_active_memory();
CHECK(active_mem >= 4096 * 4);
{
auto b = zeros({4096});
eval(b);
}
auto new_active_mem = metal::get_active_memory();
CHECK_EQ(new_active_mem, active_mem);
auto peak_mem = metal::get_peak_memory();
CHECK(peak_mem >= 4096 * 8);
// Release a
a.free(buf);
// If the release succeeded, the first byte should differ from the
// first byte before the release
unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
CHECK_NE(new_first_byte, first_byte);
auto cache_mem = metal::get_cache_memory();
CHECK(cache_mem >= 4096 * 4);
}
}

View File

@ -791,13 +791,13 @@ TEST_CASE("test reduction ops") {
constexpr float inf = std::numeric_limits<float>::infinity();
x = array({-inf, -inf});
WARN_EQ(logsumexp(x).item<float>(), -inf);
CHECK_EQ(logsumexp(x).item<float>(), -inf);
x = array({0.0f, -inf});
CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
x = array({0.0f, inf});
WARN_EQ(logsumexp(x).item<float>(), inf);
CHECK_EQ(logsumexp(x).item<float>(), inf);
x = reshape(arange(6, float32), {2, 3});
@ -1858,6 +1858,14 @@ TEST_CASE("test scatter") {
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
// Array scatters with col contiguous updates
in = zeros({4, 4}, float32);
inds = array({0, 1, 2, 3});
updates = transpose(reshape(arange(16, float32), {4, 1, 4}));
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))
.item<bool>());
// Irregular strided index and reduce collision test
in = zeros({10}, float32);
inds = broadcast_to(array(3), {10});
@ -1877,10 +1885,10 @@ TEST_CASE("test scatter") {
// Irregularly strided updates test
in = ones({3, 3});
updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});
inds = array({0});
out = scatter(in, inds, updates, 0);
CHECK(array_equal(out, zeros({3, 3})).item<bool>());
CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());
// Along different axis
in = zeros({2, 3});
@ -2185,6 +2193,8 @@ TEST_CASE("test power") {
}
TEST_CASE("test where") {
const float inf = std::numeric_limits<float>::infinity();
array condition(true);
array x(1.0f);
array y(0.0f);
@ -2216,6 +2226,49 @@ TEST_CASE("test where") {
out = where(condition, x, y);
expected = array({1, 2, 2, 1}, {2, 2});
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
condition = array(true);
x = array({1, 2, 3});
y = array({3, 6, 13});
CHECK(array_equal(where(condition, x, y), array({1, 2, 3})).item<bool>());
condition = array(false);
x = array({1, 2, 3});
y = array({3, 6, 13});
CHECK(array_equal(where(condition, x, y), array({3, 6, 13})).item<bool>());
condition = array({1, 1, 0});
x = array({1, 2, 3});
y = array({11, 12, 13});
CHECK(array_equal(where(condition, x, y), array({1, 2, 13})).item<bool>());
condition = array({true, false}, {2, 1, 1});
x = array({1, 2, 3, 4}, {2, 1, 2});
y = array({11, 22, 33, 44}, {2, 2, 1});
expected = array({1, 2, 1, 2, 33, 33, 44, 44}, {2, 2, 2});
CHECK(array_equal(where(condition, x, y), expected).item<bool>());
condition = array({true, false, false});
x = array({inf, 2.0, 3.0});
y = array({10.0, 20.0, -inf});
CHECK(array_equal(where(condition, x, y), array({inf, 20.0, -inf}))
.item<bool>());
// 4-dim optimized case.
condition = array({false});
x = array({1, 2}, {2, 1, 1, 1});
y = array({3, 4}, {1, 1, 2, 1});
CHECK(array_equal(where(condition, x, y), array({3, 4, 3, 4}, {2, 1, 2, 1}))
.item<bool>());
// 5-dim optimized case.
condition = array({true, false}, {2, 1, 1, 1, 1});
x = array({1, 2, 3, 4}, {2, 1, 1, 1, 2});
y = array({11, 22}, {1, 1, 2, 1, 1});
CHECK(array_equal(
where(condition, x, y),
array({1, 2, 1, 2, 11, 11, 22, 22}, {2, 1, 2, 1, 2}))
.item<bool>());
}
TEST_CASE("test stack") {
@ -2734,6 +2787,19 @@ TEST_CASE("test atleast_1d") {
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}
TEST_CASE("test atleast_1d vector") {
auto x = std::vector<array>{
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
auto out = atleast_1d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 1);
CHECK_EQ(out[0].shape(), std::vector<int>{1});
CHECK_EQ(out[1].ndim(), 1);
CHECK_EQ(out[1].shape(), std::vector<int>{3});
CHECK_EQ(out[2].ndim(), 2);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
}
TEST_CASE("test atleast_2d") {
auto x = array(1);
auto out = atleast_2d(x);
@ -2751,6 +2817,19 @@ TEST_CASE("test atleast_2d") {
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}
TEST_CASE("test atleast_2d vector") {
auto x = std::vector<array>{
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
auto out = atleast_2d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
CHECK_EQ(out[1].ndim(), 2);
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
CHECK_EQ(out[2].ndim(), 2);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
}
TEST_CASE("test atleast_3d") {
auto x = array(1);
auto out = atleast_3d(x);
@ -2766,4 +2845,36 @@ TEST_CASE("test atleast_3d") {
out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
}
}
TEST_CASE("test atleast_3d vector") {
auto x = std::vector<array>{
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
auto out = atleast_3d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(out[1].ndim(), 3);
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out[2].ndim(), 3);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
}
TEST_CASE("test topk") {
auto x = reshape(arange(10), {2, 5});
{
auto y = topk(x, 1, 1);
CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
}
{
auto y = topk(x, 2, 0);
CHECK(array_equal(y, x).item<bool>());
}
{
auto y = topk(x, 1, 0);
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
}
}

View File

@ -248,11 +248,9 @@ TEST_CASE("test random uniform") {
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float32);
if (is_available(float16)) {
x = random::uniform({}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
}
x = random::uniform({}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
x = random::uniform({0});
CHECK(array_equal(x, array({})).item<bool>());
@ -467,11 +465,9 @@ TEST_CASE("test random bernoulli") {
CHECK_EQ(x.dtype(), bool_);
// Bernoulli parameter can have floating point type
if (is_available(float16)) {
x = random::bernoulli(array(0.5, float16));
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), bool_);
}
x = random::bernoulli(array(0.5, float16));
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), bool_);
CHECK_THROWS(random::bernoulli(array(1, int32)));
@ -513,11 +509,9 @@ TEST_CASE("Test truncated normal") {
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float32);
if (is_available(float16)) {
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
}
x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.dtype(), float16);
// Requested shape
x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});

View File

@ -138,6 +138,70 @@ TEST_CASE("test simple vmap") {
CHECK(array_equal(out, x + y).item<bool>());
}
// vmap where (ternary op)
{
auto fun = [](std::vector<array> inputs) {
auto out = where(inputs[0], inputs[1], inputs[2]);
return std::vector<array>{out};
};
auto vfun = vmap(fun);
array cond({true, false}, {2, 1});
array x({1.0, 2.0}, {2, 1});
array y({2.0, 4.0}, {2, 1});
auto out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 4.0}, {2, 1})).item<bool>());
cond = array({true, true, false}, {1, 3});
x = ones({2, 1, 3});
y = zeros({3, 2});
vfun = vmap(fun, {1, 2, 0});
out = vfun({cond, x, y})[0];
CHECK(
array_equal(out, array({1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0}, {3, 2, 2}))
.item<bool>());
vfun = vmap(fun, {1, 2, 0}, {1});
out = vfun({cond, x, y})[0];
CHECK(
array_equal(out, array({1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0}, {2, 3, 2}))
.item<bool>());
cond = array({true, false});
x = array(2.);
y = ones({3, 2});
vfun = vmap(fun, {-1, -1, 0});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({2, 1, 2, 1, 2, 1}, {3, 2})).item<bool>());
cond = array({true, false});
x = ones({3, 2});
y = array(2.);
vfun = vmap(fun, {-1, 0, -1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1, 2, 1, 2, 1, 2}, {3, 2})).item<bool>());
CHECK_THROWS_AS(vmap(fun, {-1, -1, -1}, {0}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, 0, -1}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);
cond = array({true, false});
x = array(1.);
y = array(2.);
vfun = vmap(fun, {-1, -1, -1}, {-1});
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());
cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
x = ones({3, 2, 1});
y = full({3, 2, 1}, 2);
vfun = vmap(vmap(fun));
out = vfun({cond, x, y})[0];
CHECK(array_equal(out, array({1, 1, 1, 2, 2, 2}, {3, 2, 1})).item<bool>());
}
// vmap with capturing closure
{
auto x = add(add(ones({2}), zeros({2})), zeros({2}));