Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-27 11:21:28 +08:00)

Commit c02602a4a1: Merge branch 'ml-explore:main' into main
@@ -237,6 +237,14 @@ workflows:
    jobs:
      - mac_build_and_test
      - linux_build_and_test

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
@@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.

- Luca Arnaboldi: Added `Ceil` and `Floor` ops.
@@ -256,4 +256,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.3.0)
  set(MLX_VERSION 0.5.0)
endif()

# --------------------- Processor tests -------------------------
@@ -67,8 +67,6 @@ if (MLX_BUILD_METAL AND NOT METAL_LIB)
  set(MLX_BUILD_METAL OFF)
elseif (MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")
  add_compile_definitions(_METAL_)

  # Throw an error if xcrun not found
  execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
                  OUTPUT_VARIABLE MACOS_VERSION
README.md (10 lines changed)
@@ -11,10 +11,12 @@ brought to you by Apple machine learning research.

Some key features of MLX include:

- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
  MLX also has a fully featured C++ API, which closely mirrors the Python API.
  MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
  that closely follow PyTorch to simplify building more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
  also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
  [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
  the Python API. MLX has higher-level packages like `mlx.nn` and
  `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
  more complex models.

- **Composable function transformations**: MLX supports composable function
  transformations for automatic differentiation, automatic vectorization,
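A minimal sketch of the composable transformations mentioned above (the function and values are illustrative; `mx.grad` is part of the public Python API):

import mlx.core as mx

def f(x):
    # Scalar-valued function of an array.
    return mx.sum(mx.square(x))

# mx.grad(f) returns a new function computing df/dx; transformations
# compose, e.g. mx.grad(mx.grad(g)) or mx.vmap over a gradient.
dfdx = mx.grad(f)
print(dfdx(mx.array([1.0, 2.0, 3.0])))  # array([2, 4, 6], dtype=float32)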
@@ -73,6 +73,7 @@ void time_unary_ops() {

void time_binary_ops() {
  int M = 1000, N = 100, K = 10;
  auto condition = random::randint(0, 2, {M, N, K});
  auto a = random::uniform({M, N, K});
  auto b = random::uniform({M, N, K});
  auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
  TIME(divide, a, b, device);
  TIME(maximum, a, b, device);
  TIME(minimum, a, b, device);
  TIME(where, condition, a, b, device);

  condition = array({true});
  b = random::uniform({1});
  eval(b);
  TIMEM("scalar", add, a, b, device);
@@ -93,7 +96,9 @@ void time_binary_ops() {
  TIMEM("scalar", multiply, a, b, device);
  TIMEM("vector-scalar", divide, a, b, device);
  TIMEM("scalar-vector", divide, b, a, device);
  TIMEM("scalar-vector", where, condition, a, b, device);

  condition = broadcast_to(array({true}), {1000, 100});
  a = broadcast_to(random::uniform({1}), {1000, 100});
  b = broadcast_to(random::uniform({1}), {1000, 100});
  eval(a, b);
@@ -101,6 +106,7 @@ void time_binary_ops() {
  TIMEM("scalar-scalar broadcast", subtract, a, b, device);
  TIMEM("scalar-scalar broadcast", multiply, a, b, device);
  TIMEM("scalar-scalar broadcast", divide, a, b, device);
  TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}

void time_strided_ops() {
benchmarks/python/conv_bench.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding)
    f_pt = make_pt_conv_2D(strides, padding)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
    )

    for dtype in dtypes:
        print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
        for N, H, W, C, kH, kW, O, strides, padding in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
@@ -7,12 +7,14 @@ import torch
from time_utils import measure_runtime


def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    def scatter(dst, x, idx):
        dst[idx] = x
        dst[*idx] = x
        mx.eval(dst)

    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
    idx = []
    for idx_shape in idx_shapes:
        idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

@@ -20,13 +22,15 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
    print(f"MLX: {runtime:.3f}ms")


def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    def gather(dst, x, idx, device):
        dst[idx] = x
        dst[*idx] = x
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
    idx = []
    for idx_shape in idx_shapes:
        idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

@@ -45,9 +49,45 @@ if __name__ == "__main__":
    else:
        device = torch.device("mps")

    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
    dst_shapes = [
        (10, 64),
        (100_000, 64),
        (1_000_000, 64),
        (100_000,),
        (2_000_00,),
        (20_000_000,),
        (10000, 64),
        (100, 64),
        (100, 10_000, 64),
        (10, 100, 100, 21),
        (1_000, 1_000, 10),
    ]
    idx_shapes = [
        [(1_000_000,)],
        [(1_000_000,)],
        [(100_000,)],
        [(1_000_000,)],
        [(20_000_000,)],
        [(20_000_000,)],
        [(1000000,)],
        [(10000000,)],
        [(1_000,)],
        [(10_000,)],
        [(1_000,), (1_000,)],
    ]
    x_shapes = [
        (1_000_000, 64),
        (1_000_000, 64),
        (100_000, 64),
        (1_000_000,),
        (20_000_000,),
        (20_000_000,),
        (1000000, 64),
        (10000000, 64),
        (1_000, 10_000, 64),
        (10_000, 100, 100, 21),
        (1_000, 10),
    ]

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
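For context on the `dst[*idx] = x` change above: `idx` is now a list of index arrays, and unpacking it into the subscript scatters along multiple axes at once (star-unpacking inside `[]` needs Python 3.11+). A small illustrative sketch with arbitrary shapes:

import mlx.core as mx

dst = mx.zeros((4, 4))
idx = [mx.array([0, 1]), mx.array([2, 3])]
# Equivalent to dst[idx[0], idx[1]] = ...: writes dst[0, 2] and dst[1, 3].
dst[*idx] = mx.array([7.0, 9.0])
mx.eval(dst)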
docs/src/_static/mlx_logo.png (binary file changed: 7.2 KiB before, 76 KiB after; not shown)
docs/src/_static/mlx_logo_dark.png (new binary file, 48 KiB; not shown)
@@ -49,10 +49,12 @@ html_theme_options = {
    "repository_url": "https://github.com/ml-explore/mlx",
    "use_repository_button": True,
    "navigation_with_keys": False,
    "logo": {
        "image_light": "_static/mlx_logo.png",
        "image_dark": "_static/mlx_logo_dark.png",
    },
}

html_logo = "_static/mlx_logo.png"


# -- Options for HTMLHelp output ---------------------------------------------
@@ -64,6 +64,7 @@ are the CPU and GPU.
   python/transforms
   python/fft
   python/linalg
   python/metal
   python/nn
   python/optimizers
   python/tree_utils
docs/src/python/metal.rst (new file, 14 lines)
@@ -0,0 +1,14 @@
Metal
=====

.. currentmodule:: mlx.core.metal

.. autosummary::
   :toctree: _autosummary

   is_available
   get_active_memory
   get_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
@@ -12,13 +12,24 @@ simple functions.
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   elu
   gelu
   gelu_approx
   gelu_fast_approx
   glu
   hardswish
   leaky_relu
   log_sigmoid
   log_softmax
   mish
   prelu
   relu
   relu6
   selu
   softshrink
   sigmoid
   silu
   softmax
   softplus
   softshrink
   step
   tanh
@@ -40,3 +40,4 @@ Layers
   Softshrink
   Step
   Transformer
   Upsample
@@ -35,6 +35,7 @@ Operations
   convolve
   conv1d
   conv2d
   conv_general
   cos
   cosh
   dequantize
@@ -56,6 +57,7 @@ Operations
   greater_equal
   identity
   inner
   isclose
   isnan
   isposinf
   isneginf
@@ -120,6 +122,8 @@ Operations
   tan
   tanh
   tensordot
   tile
   topk
   transpose
   tri
   tril
@@ -8,6 +8,8 @@ Schedulers
.. autosummary::
   :toctree: _autosummary

   step_decay
   exponential_decay
   cosine_decay
   exponential_decay
   join_schedules
   linear_schedule
   step_decay
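A short sketch of how a schedule from this list can drive an optimizer's learning rate; the particular arguments and the choice of SGD are illustrative assumptions:

import mlx.optimizers as optim

# Start at 0.1 and decay by 0.9 every 10 steps (assumed step_decay signature).
lr_schedule = optim.step_decay(0.1, 0.9, 10)
optimizer = optim.SGD(learning_rate=lr_schedule)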
@@ -64,6 +64,7 @@ DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
@@ -1,6 +1,9 @@

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  set(COMPILER ${CMAKE_C_COMPILER})
  set(CLANG TRUE)
else()
  set(COMPILER ${CMAKE_CXX_COMPILER})
endif()

add_custom_command(
@@ -8,16 +11,16 @@ add_custom_command(
  COMMAND /bin/bash
    ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
    ${CMAKE_CXX_COMPILER}
    ${CMAKE_SOURCE_DIR}
    ${COMPILER}
    ${PROJECT_SOURCE_DIR}
    ${CLANG}

  DEPENDS make_compiled_preamble.sh
    compiled_preamble.h
    ${CMAKE_SOURCE_DIR}/mlx/types/half_types.h
    ${CMAKE_SOURCE_DIR}/mlx/types/fp16.h
    ${CMAKE_SOURCE_DIR}/mlx/types/bf16.h
    ${CMAKE_SOURCE_DIR}/mlx/types/complex.h
    ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
    ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
    ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
    ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
    ops.h
)
@@ -43,6 +46,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
@@ -9,7 +9,7 @@ namespace mlx::core {

namespace {

enum BinaryOpType {
enum class BinaryOpType {
  ScalarScalar,
  ScalarVector,
  VectorScalar,
@@ -20,17 +20,17 @@ enum BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = ScalarScalar;
    bopt = BinaryOpType::ScalarScalar;
  } else if (a.data_size() == 1 && b.flags().contiguous) {
    bopt = ScalarVector;
    bopt = BinaryOpType::ScalarVector;
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    bopt = VectorScalar;
    bopt = BinaryOpType::VectorScalar;
  } else if (
      a.flags().row_contiguous && b.flags().row_contiguous ||
      a.flags().col_contiguous && b.flags().col_contiguous) {
    bopt = VectorVector;
    bopt = BinaryOpType::VectorVector;
  } else {
    bopt = General;
    bopt = BinaryOpType::General;
  }
  return bopt;
}
@@ -42,11 +42,11 @@ void set_binary_op_output_data(
    BinaryOpType bopt,
    bool donate_with_move = false) {
  switch (bopt) {
    case ScalarScalar:
    case BinaryOpType::ScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case ScalarVector:
    case BinaryOpType::ScalarVector:
      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
@@ -61,7 +61,7 @@ void set_binary_op_output_data(
          b.flags());
      }
      break;
    case VectorScalar:
    case BinaryOpType::VectorScalar:
      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
@@ -76,7 +76,7 @@ void set_binary_op_output_data(
          a.flags());
      }
      break;
    case VectorVector:
    case BinaryOpType::VectorVector:
      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
@@ -97,7 +97,7 @@ void set_binary_op_output_data(
          a.flags());
      }
      break;
    case General:
    case BinaryOpType::General:
      if (a.is_donatable() && a.flags().row_contiguous &&
          a.itemsize() == out.itemsize() && a.size() == out.size()) {
        if (donate_with_move) {
@@ -424,25 +424,25 @@ void binary_op(
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == ScalarScalar) {
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == ScalarVector) {
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == VectorScalar) {
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == VectorVector) {
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }
@@ -475,17 +475,17 @@ void binary_op(
  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = VectorVector;
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = VectorScalar;
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = ScalarVector;
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

@@ -495,20 +495,20 @@ void binary_op(
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = General;
    bopt = BinaryOpType::General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case VectorVector:
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
      break;
    case VectorScalar:
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
      break;
    case ScalarVector:
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
      break;
    default:

@@ -260,14 +260,14 @@ void binary_op(
  set_binary_op_output_data(a, b, out_b, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == ScalarScalar) {
  if (bopt == BinaryOpType::ScalarScalar) {
    std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
        op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == ScalarVector) {
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(
        a.data<T>(),
        b.data<T>(),
@@ -278,7 +278,7 @@ void binary_op(
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == VectorScalar) {
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(
        a.data<T>(),
        b.data<T>(),
@@ -289,7 +289,7 @@ void binary_op(
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == VectorVector) {
  if (bopt == BinaryOpType::VectorVector) {
    opvv(
        a.data<T>(),
        b.data<T>(),
@@ -327,17 +327,17 @@ void binary_op(
  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = VectorVector;
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = VectorScalar;
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = ScalarVector;
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

@@ -347,20 +347,20 @@ void binary_op(
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = General;
    bopt = BinaryOpType::General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case VectorVector:
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
      break;
    case VectorScalar:
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
      break;
    case ScalarVector:
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
      break;
    default:
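The dispatch above picks a loop structure from operand contiguity so that fully contiguous cases hit tight vectorized loops. A simplified Python sketch of the same classification (a model of the logic, not the MLX code):

def classify_binary_op(a_size, a_contig, b_size, b_contig):
    # Mirrors get_binary_op_type: size-1 operands are scalars, and two
    # contiguous operands can be processed as flat vectors.
    if a_size == 1 and b_size == 1:
        return "ScalarScalar"
    if a_size == 1 and b_contig:
        return "ScalarVector"
    if b_size == 1 and a_contig:
        return "VectorScalar"
    if a_contig and b_contig:
        return "VectorVector"
    return "General"  # fall back to strided iteration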
@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <numeric>

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
@@ -27,14 +28,16 @@ void slow_conv_1D(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const T* start_wt_ptr = wt.data<T>();

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();

  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
  const int oH = out.shape(1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(2); // In channels
@@ -61,12 +64,15 @@ void slow_conv_1D(
      for (int wh = 0; wh < wH; ++wh) {
        const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

        int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
        int wh_flip = flip ? (wH - wh - 1) : wh;
        int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];

        if (ih >= 0 && ih < iH) {
        auto ih_div = std::div(ih, in_dilation[0]);

        if (ih >= 0 && ih < iH && ih_div.rem == 0) {
          for (int c = 0; c < C; ++c) {
            r += static_cast<float>(
                     in_ptr[ih * in_stride_H + c * in_stride_C]) *
                     in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
                static_cast<float>(wt_ptr[c * wt_stride_C]);
          } // c

@@ -90,14 +96,16 @@ void slow_conv_2D(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const T* st_wt_ptr = wt.data<T>();
  const T* st_in_ptr = in.data<T>();
  T* st_out_ptr = out.data<T>();

  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int iW = in.shape(2); // Input spatial dim
  const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
  const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
  const int oH = out.shape(1); // Output spatial dim
  const int oW = out.shape(2); // Output spatial dim
  const int O = wt.shape(0); // Out channels
@@ -120,6 +128,8 @@ void slow_conv_2D(
  const size_t out_stride_W = out.strides()[2];
  const size_t out_stride_O = out.strides()[3];

  bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;

  auto pt_conv_no_checks =
      [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
        out_ptr += oh * out_stride_H + ow * out_stride_W;
@@ -131,8 +141,10 @@ void slow_conv_2D(

          for (int wh = 0; wh < wH; ++wh) {
            for (int ww = 0; ww < wW; ++ww) {
              int ih = ih_base + wh * wt_dilation[0];
              int iw = iw_base + ww * wt_dilation[1];
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int ih = ih_base + wh_flip * wt_dilation[0];
              int iw = iw_base + ww_flip * wt_dilation[1];

              const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
              const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
@@ -153,25 +165,74 @@ void slow_conv_2D(
        } // o
      };

  int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
  int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];

  int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
  int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);

  int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
  int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];

  int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
  int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];

  std::vector<int> base_h(f_out_jump_h);
  std::vector<int> base_w(f_out_jump_w);

  for (int i = 0; i < f_out_jump_h; ++i) {
    int ih_loop = i * wt_strides[0] - padding[0] + init_h;

    int wh_base = 0;
    while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
      wh_base++;
      ih_loop += jump_h;
    }

    base_h[i] = wh_base;
  }

  for (int j = 0; j < f_out_jump_w; ++j) {
    int iw_loop = j * wt_strides[1] - padding[1] + init_w;

    int ww_base = 0;
    while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
      ww_base++;
      iw_loop += jump_w;
    }

    base_w[j] = ww_base;
  }

  auto pt_conv_all_checks =
      [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
        out_ptr += oh * out_stride_H + ow * out_stride_W;

        int ih_base = oh * wt_strides[0] - padding[0];
        int iw_base = ow * wt_strides[1] - padding[1];

        int wh_base = base_h[oh % f_out_jump_h];
        int ww_base = base_w[ow % f_out_jump_w];

        for (int o = 0; o < O; ++o) {
          float r = 0.;

          for (int wh = 0; wh < wH; ++wh) {
            for (int ww = 0; ww < wW; ++ww) {
              int ih = ih_base + wh * wt_dilation[0];
              int iw = iw_base + ww * wt_dilation[1];
          for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
            for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int ih = ih_base + wh_flip * wt_dilation[0];
              int iw = iw_base + ww_flip * wt_dilation[1];

              if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
                const T* wt_ptr_pt =
                    wt_ptr + wh * wt_stride_H + ww * wt_stride_W;

                int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
                int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

                const T* in_ptr_pt =
                    in_ptr + ih * in_stride_H + iw * in_stride_W;
                    in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;

                for (int c = 0; c < C; ++c) {
                  r += static_cast<float>(in_ptr_pt[0]) *
@@ -191,13 +252,17 @@ void slow_conv_2D(
      };

  int oH_border_0 = 0;
  int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
  int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
  int oH_border_1 =
      is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
  int oH_border_2 = std::max(
      oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
  int oH_border_3 = oH;

  int oW_border_0 = 0;
  int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
  int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
  int oW_border_1 =
      is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
  int oW_border_2 = std::max(
      oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
  int oW_border_3 = oW;

  for (int n = 0; n < N; ++n) {
@@ -246,15 +311,18 @@ void dispatch_slow_conv_1D(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (in.dtype() == float32) {
    return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
    return slow_conv_1D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == float16) {
    return slow_conv_1D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation);
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_1D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation);
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
@@ -267,15 +335,18 @@ void dispatch_slow_conv_2D(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (in.dtype() == float32) {
    return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
    return slow_conv_2D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == float16) {
    return slow_conv_2D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation);
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_2D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation);
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
@@ -493,13 +564,16 @@ void conv_1D_cpu(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
  if (wt_dilation[0] == 1) {
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
    return explicit_gemm_conv_1D_cpu(
        in, wt, out, padding, wt_strides, wt_dilation);
  }

  return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
  return dispatch_slow_conv_1D(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

void conv_2D_cpu(
@@ -508,8 +582,11 @@ void conv_2D_cpu(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
  return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  return dispatch_slow_conv_2D(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}

} // namespace
@@ -523,12 +600,26 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
  // 2D convolution
  if (in.ndim() == (2 + 2)) {
    return conv_2D_cpu(
        in, wt, out, padding_, kernel_strides_, kernel_dilation_);
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_);
  }
  // 1D convolution
  else if (in.ndim() == (1 + 2)) {
    return conv_1D_cpu(
        in, wt, out, padding_, kernel_strides_, kernel_dilation_);
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_);
  }
  // Throw error
  else {
@@ -87,6 +87,7 @@ DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
@@ -7,6 +7,10 @@

namespace mlx::core::detail {

namespace {
constexpr float inf = std::numeric_limits<float>::infinity();
} // namespace

typedef union {
  int i;
  float f;
@@ -588,4 +592,11 @@ struct LogicalOr {
  };
};

struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};

} // namespace mlx::core::detail
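The `Select` functor above is the elementwise kernel behind the ternary `where` operation; from the Python API the same computation looks like this (values illustrative):

import mlx.core as mx

cond = mx.array([True, False, True])
x = mx.array([1.0, 2.0, 3.0])
y = mx.array([10.0, 20.0, 30.0])
print(mx.where(cond, x, y))  # array([1, 20, 3], dtype=float32)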
mlx/backend/common/select.cpp (new file, 72 lines)
@@ -0,0 +1,72 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename Op>
void select_op(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  switch (out.dtype()) {
    case bool_:
      ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
      break;
    case uint8:
      ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
      break;
    case uint16:
      ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
      break;
    case uint32:
      ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
      break;
    case uint64:
      ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
      break;
    case int8:
      ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
      break;
    case int16:
      ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
      break;
    case int32:
      ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
      break;
    case int64:
      ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
      break;
    case float16:
      ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
      break;
    case float32:
      ternary_op<bool, float, float, float>(a, b, c, out, op);
      break;
    case bfloat16:
      ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
      break;
    case complex64:
      ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
      break;
  }
}

} // namespace

void Select::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 3);
  const auto& condition = inputs[0];
  const auto& a = inputs[1];
  const auto& b = inputs[2];
  select_op(condition, a, b, out, detail::Select());
}

} // namespace mlx::core
mlx/backend/common/ternary.h (new file, 226 lines)
@@ -0,0 +1,226 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
  ScalarScalarScalar,
  General,
};

TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
  TernaryOpType topt;
  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
    topt = TernaryOpType::ScalarScalarScalar;
  } else {
    topt = TernaryOpType::General;
  }
  return topt;
}

void set_ternary_op_output_data(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    TernaryOpType topt,
    bool donate_with_move = false) {
  switch (topt) {
    case TernaryOpType::ScalarScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
      break;
    case TernaryOpType::General:
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
      break;
  }
}

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims1(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();

  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t c_idx = 0;
  for (size_t i = 0; i < out.size(); ++i) {
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
    c_idx += c.strides()[0];
  }
}

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims2(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();

  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t c_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
      c_idx += c.strides()[1];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
  }
}

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims3(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t c_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
        c_idx += c.strides()[2];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
  }
}

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dims4(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();

  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t c_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
          c_idx += c.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
        c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
  }
}

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  switch (out.ndim()) {
    case 1:
      ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
      return;
    case 2:
      ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
      return;
    case 3:
      ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
      return;
    case 4:
      ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
      return;
  }

  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i++) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    int c_idx = elem_to_loc(i, c.shape(), c.strides());
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
  }
}

template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  TernaryOpType topt = get_ternary_op_type(a, b, c);
  set_ternary_op_output_data(a, b, c, out, topt);

  // The full computation is scalar-scalar-scalar so we call the base op once.
  if (topt == TernaryOpType::ScalarScalarScalar) {
    *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
    return;
  }

  ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
}

} // namespace

} // namespace mlx::core
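The general fallback above converts each flat output index into a storage offset with `elem_to_loc`. A simplified Python model of that index arithmetic (not the MLX helper itself):

def elem_to_loc(i, shape, strides):
    # Peel off the coordinate along each axis, innermost first,
    # and weight it by that axis's stride.
    loc = 0
    for dim in range(len(shape) - 1, -1, -1):
        loc += (i % shape[dim]) * strides[dim]
        i //= shape[dim]
    return loc

# A 2x3 transposed view with strides (1, 2): flat index 3 -> coords (1, 0)
# -> storage offset 1.
assert elem_to_loc(3, (2, 3), (1, 2)) == 1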
@@ -4,7 +4,7 @@ add_custom_command(
    ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
    ${CMAKE_C_COMPILER}
    ${CMAKE_SOURCE_DIR}
    ${PROJECT_SOURCE_DIR}
  DEPENDS make_compiled_preamble.sh
    kernels/compiled_preamble.h
    kernels/unary.h
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
@@ -23,16 +23,6 @@ void* Buffer::raw_ptr() {

namespace metal {

static bool cache_enabled_ = true;

bool cache_enabled() {
  return cache_enabled_;
}

void set_cache_enabled(bool enabled) {
  cache_enabled_ = enabled;
}

namespace {

BufferCache::BufferCache(MTL::Device* device)
@@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
    : device_(device(mlx::core::Device::gpu).mtl_device()),
      buffer_cache_(device_),
      peak_allocated_size_(0),
      block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
      max_pool_size_(block_limit_) {}

size_t MetalAllocator::set_cache_limit(size_t limit) {
  std::swap(limit, max_pool_size_);
  return limit;
};

size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
  std::swap(limit, block_limit_);
  relaxed_ = relaxed;
  gc_limit_ = std::min(
      block_limit_,
      static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
  return limit;
};

Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
  // Metal doesn't like empty buffers
@@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {

  // Try the cache
  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);

  size_t pool_size = get_cache_memory();
  if (!buf) {
    size_t mem_required = get_active_memory() + pool_size + size;

    // If there is too much memory pressure, fail (likely causes a wait).
    if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
      return Buffer{nullptr};
    }

@@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {

    // If we have a lot of memory pressure, check if we can reclaim some memory
    // from the cache
    if (device_->currentAllocatedSize() + size >= gc_limit_) {
      size_t min_bytes_to_free =
          size + device_->currentAllocatedSize() - gc_limit_;
      buffer_cache_.release_cached_buffers(min_bytes_to_free);
    if (mem_required >= gc_limit_) {
      buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
    }

    // Allocate new buffer if needed
@@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
    buf = device_->newBuffer(size, res_opt);
  }

  peak_allocated_size_ =
      std::max(peak_allocated_size_, device_->currentAllocatedSize());
  // Maintain the cache below the requested limit
  if (pool_size >= max_pool_size_) {
    auto thread_pool = metal::new_scoped_memory_pool();
    buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
  }

  active_memory_ += buf->length();
  peak_memory_ = std::max(peak_memory_, active_memory_);

  return Buffer{static_cast<void*>(buf)};
}

void MetalAllocator::free(Buffer buffer) {
  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
  if (cache_enabled()) {
  active_memory_ -= buf->length();
  if (max_pool_size_ > 0) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    buf->release();
@@ -218,6 +229,22 @@ MetalAllocator& allocator() {
  return allocator_;
}

size_t set_cache_limit(size_t limit) {
  return allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
  return allocator().set_memory_limit(limit, relaxed);
}
size_t get_active_memory() {
  return allocator().get_active_memory();
}
size_t get_peak_memory() {
  return allocator().get_peak_memory();
}
size_t get_cache_memory() {
  return allocator().get_cache_memory();
}

} // namespace metal

} // namespace mlx::core
@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@@ -24,6 +24,9 @@ class BufferCache {
  MTL::Buffer* reuse_from_cache(size_t size);
  void recycle_to_cache(MTL::Buffer* buf);
  void release_cached_buffers(size_t min_bytes_to_free);
  size_t pool_size() {
    return pool_size_;
  }

 private:
  struct BufferHolder {
@@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator {
 public:
  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
  virtual void free(Buffer buffer) override;
  size_t get_active_memory() {
    return active_memory_;
  };
  size_t get_peak_memory() {
    return peak_memory_;
  };
  size_t get_cache_memory() {
    return buffer_cache_.pool_size();
  };
  size_t set_cache_limit(size_t limit);
  size_t set_memory_limit(size_t limit, bool relaxed);

 private:
  MTL::Device* device_;
@@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator {
  BufferCache buffer_cache_;

  // Allocation stats
  size_t peak_allocated_size_;
  size_t block_limit_;
  size_t gc_limit_;
  size_t active_memory_{0};
  size_t peak_memory_{0};
  size_t max_pool_size_;
  bool relaxed_{true};
};

MetalAllocator& allocator();
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@ -7,80 +7,72 @@
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/conv_params.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void explicit_gemm_conv_1D_gpu(
|
||||
template <int N>
|
||||
void explicit_gemm_conv_ND_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<1>& conv_params) {
|
||||
// Pad input
|
||||
std::vector<int> padded_shape = {
|
||||
conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C};
|
||||
array in_padded(padded_shape, in.dtype(), nullptr, {});
|
||||
const MLXConvParams<N>& conv_params) {
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape = {
|
||||
static_cast<int>(out.size() / conv_params.O),
|
||||
static_cast<int>(wt.size() / conv_params.O)};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
// Fill with zeros
|
||||
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
|
||||
// Pick input slice from padded
|
||||
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
|
||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||
in_padded_slice.copy_shared_buffer(
|
||||
in_padded,
|
||||
in_padded.strides(),
|
||||
in_padded.flags(),
|
||||
in_padded_slice.size(),
|
||||
data_offset);
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, in_unfolded, 1);
|
||||
|
||||
// Make strided view
|
||||
std::vector<int> strided_shape = {
|
||||
conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C};
|
||||
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
|
||||
|
||||
std::vector<size_t> strided_strides = {
|
||||
in_padded.strides()[0],
|
||||
in_padded.strides()[1] * conv_params.str[0],
|
||||
in_padded.strides()[1],
|
||||
in_padded.strides()[2]};
|
||||
auto flags = in_padded.flags();
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
|
||||
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
|
||||
in_strided_view.copy_shared_buffer(
|
||||
in_padded, strided_strides, flags, in_strided_view.size(), 0);
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
|
||||
// Materialize strided view
|
||||
std::vector<int> strided_reshape = {
|
||||
conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C};
|
||||
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
|
||||
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies = {in_padded, in_strided};
|
||||
std::vector<array> copies;
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_strided,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt,
|
||||
/*c = */ out,
|
||||
/*M = */ strided_reshape[0],
|
||||
/*M = */ unfolded_shape[0],
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ strided_reshape[1],
|
||||
/*K = */ unfolded_shape[1],
|
||||
/*batch_size_out = */ 1,
|
||||
/*a_cols = */ strided_reshape[1],
|
||||
/*b_cols = */ strided_reshape[1],
|
||||
/*a_cols = */ unfolded_shape[1],
|
||||
/*b_cols = */ unfolded_shape[1],
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/*copies = */ copies);
|
||||
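Note: the unfold buffer built here is the classic im2col transformation — each output position contributes one row holding the wS * C input taps it reads, so the convolution reduces to a single [N*oS, wS*C] x [wS*C, O] GEMM against the flattened weights. A minimal CPU sketch of the 1-D case for reference (im2col_1d is an illustrative helper, not MLX API; unit dilation assumed):

// Minimal CPU sketch of the unfold-then-GEMM strategy for a 1-D conv,
// assuming NLC layout, stride `str`, padding `pad`, and unit dilation.
// `im2col_1d` is an illustrative helper, not part of MLX.
#include <cstddef>
#include <vector>

std::vector<float> im2col_1d(
    const std::vector<float>& in, // [N, iL, C] flattened
    int N, int iL, int C, int wL, int str, int pad) {
  int oL = (iL + 2 * pad - wL) / str + 1;
  std::vector<float> unfolded(std::size_t(N) * oL * wL * C, 0.f);
  for (int n = 0; n < N; ++n)
    for (int o = 0; o < oL; ++o)
      for (int w = 0; w < wL; ++w) {
        int i = o * str - pad + w; // input position read by this tap
        if (i < 0 || i >= iL) continue; // zero padding stays zero
        for (int c = 0; c < C; ++c)
          unfolded[((std::size_t(n) * oL + o) * wL + w) * C + c] =
              in[(std::size_t(n) * iL + i) * C + c];
      }
  return unfolded; // GEMM: [N*oL, wL*C] x [wL*C, O] -> [N*oL, O]
}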
@ -94,7 +86,9 @@ void conv_1D_gpu(
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  // Make conv params
  MLXConvParams<1> conv_params{
      /* const int N = */ in.shape(0),
@ -105,24 +99,19 @@ void conv_1D_gpu(
      /* const int oS[NDIM] = */ {out.shape(1)},
      /* const int str[NDIM] = */ {wt_strides[0]},
      /* const int pad[NDIM] = */ {padding[0]},
      /* const int dil[NDIM] = */ {wt_dilation[0]},
      /* const int kdil[NDIM] = */ {wt_dilation[0]},
      /* const int idil[NDIM] = */ {in_dilation[0]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2]},
  };
      /* const int groups = */ 1,
      /* const bool flip = */ flip};

  // Direct to explicit gemm conv
  if (wt_dilation[0] == 1) {
    explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to fallback conv
  else {
    throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1.");
  }
  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}

void slow_conv_2D_gpu(
@ -168,113 +157,262 @@ void implicit_gemm_conv_2D_gpu(
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  int bm = 32, bn = 32, bk = 16;
  // Deduce implicit gemm size
  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  int implicit_N = conv_params.O;
  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;

  // Determine block and warp tiles
  int wm = 2, wn = 2;

  int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
  int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
  int bk = 16;

  if (implicit_N <= 16) {
    bn = 8;
    wm = 4;
    wn = 1;
  }

  int tn = (implicit_N + bn - 1) / bn;
  int tm = (implicit_M + bm - 1) / bm;
  int swizzle_log = 0;

  // Fix small channel specialization
  int n_channel_specialization = 0;
  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
  int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;

  if (conv_params.C <= 2) {
    gemm_k_iters = (implicit_K + bk - 1) / bk;
    n_channel_specialization = conv_params.C;
  } else if (conv_params.C <= 4) {
    gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
    n_channel_specialization = conv_params.C;
  }

  bool small_filter = (!n_channel_specialization) &&
      (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);

  // Fix host side helper params
  int sign = (conv_params.flip ? -1 : 1);
  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];

  int inp_jump_w = sign * ijw;
  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
      sign * (conv_params.wS[1] - 1) * ijw;

  // Build implicit gemm params
  ImplicitGemmConv2DParams gemm_params{
      /* const int M = */ implicit_M,
      /* const int N = */ implicit_N,
      /* const int K = */ implicit_K,

      /* const int gemm_k_iterations = */ gemm_k_iters,

      /* const int inp_jump_w = */ inp_jump_w,
      /* const int inp_jump_h = */ inp_jump_h,
      /* const int inp_jump_c = */ inp_jump_c,

      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int swizzle_log = */ swizzle_log};

  // Determine kernel
  std::ostringstream kname;
  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
        << (n_channel_specialization ? std::to_string(n_channel_specialization)
                                     : "l")
        << "_filter_" << (small_filter ? 's' : 'l');

  // Encode and dispatch kernel
  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  int implicit_N = conv_params.O;

  size_t grid_dim_x = (implicit_N + bn - 1) / bn;
  size_t grid_dim_y = (implicit_M + bm - 1) / bm;
  // Deduce grid launch dimensions
  int tile = 1 << swizzle_log;
  size_t grid_dim_y = (tm + tile - 1) / tile;
  size_t grid_dim_x = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);

  // Encode arrays
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, wt, 1);
  set_array_buffer(compute_encoder, out, 2);

  // Encode params
  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);

  // Launch kernel
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
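Note (worked example, parameters chosen for illustration): with N = 1, oS = {32, 32}, O = 64, wS = {3, 3}, C = 64, the implicit GEMM has implicit_M = 1 * 32 * 32 = 1024, implicit_N = 64 and implicit_K = 3 * 3 * 64 = 576. Since implicit_M < 8192, bm = 32, and implicit_N >= 64 gives bn = 64, so tm = ceil(1024 / 32) = 32 by tn = ceil(64 / 64) = 1 threadgroup tiles are launched, with gemm_k_iterations = 3 * 3 * ceil(64 / 16) = 36 and kernel suffix "_channel_l_filter_s" (no channel specialization, small filter).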
void explicit_gemm_conv_2D_gpu(
void implicit_gemm_conv_2D_general_gpu(
    const Stream& s,
    metal::Device& d,
    const array& in,
    const array& wt,
    array out,
    const MLXConvParams<2>& conv_params) {
  // Pad input
  std::vector<int> padded_shape = {
      conv_params.N,
      conv_params.iS[0] + 2 * conv_params.pad[0],
      conv_params.iS[1] + 2 * conv_params.pad[1],
      conv_params.C};
  array in_padded(padded_shape, in.dtype(), nullptr, {});
  // Deduce implicit gemm size
  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
  int implicit_N = conv_params.O;
  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;

  // Fill with zeros
  copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
  // Determine block and warp tiles
  int wm = 2, wn = 2;

  // Pick input slice from padded
  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
      conv_params.pad[1] * in_padded.strides()[2];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);
  // Make jump params
  int f_wgt_jump_h =
      std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
  int f_wgt_jump_w =
      std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];

  // Copy input values into the slice
  copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
  int f_out_jump_h =
      std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
  int f_out_jump_w =
      std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];

  // Make strided view
  std::vector<int> strided_shape = {
      conv_params.N,
      conv_params.oS[0],
      conv_params.oS[1],
      conv_params.wS[0],
      conv_params.wS[1],
      conv_params.C};
  int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
  int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
  int adj_out_hw = adj_out_h * adj_out_w;
  int adj_implicit_m = conv_params.N * adj_out_hw;

  std::vector<size_t> strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * conv_params.str[0],
      in_padded.strides()[2] * conv_params.str[1],
      in_padded.strides()[1],
      in_padded.strides()[2],
      in_padded.strides()[3]};
  auto flags = in_padded.flags();
  Conv2DGeneralJumpParams jump_params{
      /* const int f_wgt_jump_h = */ f_wgt_jump_h,
      /* const int f_wgt_jump_w = */ f_wgt_jump_w,

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);
      /* const int f_out_jump_h = */ f_out_jump_h,
      /* const int f_out_jump_w = */ f_out_jump_w,

  // Materialize strided view
  std::vector<int> strided_reshape = {
      conv_params.N * conv_params.oS[0] * conv_params.oS[1],
      conv_params.wS[0] * conv_params.wS[1] * conv_params.C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_gpu(in_strided_view, in_strided, CopyType::General, s);
      /* const int adj_out_h = */ adj_out_h,
      /* const int adj_out_w = */ adj_out_w,
      /* const int adj_out_hw = */ adj_out_hw,
      /* const int adj_implicit_m = */ adj_implicit_m};

  // Perform gemm
  std::vector<array> copies = {in_padded, in_strided};
  return steel_matmul(
      s,
      d,
      /*a = */ in_strided,
      /*b = */ wt,
      /*c = */ out,
      /*M = */ strided_reshape[0],
      /*N = */ conv_params.O,
      /*K = */ strided_reshape[1],
      /*batch_size_out = */ 1,
      /*a_cols = */ strided_reshape[1],
      /*b_cols = */ strided_reshape[1],
      /*a_transposed = */ false,
      /*b_transposed = */ true,
      /*copies = */ copies);
  // Make base info
  std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
  std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);

  int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
  int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];

  int init_h =
      (conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
  int init_w =
      (conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);

  for (int i = 0; i < f_out_jump_h; ++i) {
    int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;

    int wh_base = 0;
    while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
      wh_base++;
      ih_loop += jump_h;
    }

    int wh_size =
        ((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
    base_h[i] = {wh_base, wh_size};
  }

  for (int j = 0; j < f_out_jump_w; ++j) {
    int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;

    int ww_base = 0;
    while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
      ww_base++;
      iw_loop += jump_w;
    }

    int ww_size =
        ((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
    base_w[j] = {ww_base, ww_size};
  }

  // Collect block sizes
  int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
  int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
  int bk = 16;

  int tn = (implicit_N + bn - 1) / bn;
  int tm = (adj_implicit_m + bm - 1) / bm;
  int swizzle_log = 0;

  // Get channel iteration info
  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
  int gemm_k_iters = channel_k_iters;

  // Fix host side helper params
  int sign = (conv_params.flip ? -1 : 1);
  int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
  int ijh = conv_params.in_strides[1] * conv_params.kdil[0];

  int inp_jump_w = sign * ijw;
  int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
  int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
      sign * (conv_params.wS[1] - 1) * ijw;

  // Build implicit gemm params
  ImplicitGemmConv2DParams gemm_params{
      /* const int M = */ implicit_M,
      /* const int N = */ implicit_N,
      /* const int K = */ implicit_K,

      /* const int gemm_k_iterations = */ gemm_k_iters,

      /* const int inp_jump_w = */ inp_jump_w,
      /* const int inp_jump_h = */ inp_jump_h,
      /* const int inp_jump_c = */ inp_jump_c,

      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int swizzle_log = */ swizzle_log};

  // Determine kernel
  std::ostringstream kname;
  kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
        << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;

  // Encode and dispatch kernel
  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  // Deduce grid launch dimensions
  int tile = 1 << swizzle_log;
  size_t grid_dim_y = (tm + tile - 1) / tile;
  size_t grid_dim_x = tn * tile;
  size_t grid_dim_z = f_out_jump_h * f_out_jump_w;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);

  // Encode arrays
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, wt, 1);
  set_array_buffer(compute_encoder, out, 2);

  // Encode params
  compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
  compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
  compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);

  compute_encoder->setBytes(
      base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
  compute_encoder->setBytes(
      base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);

  // Launch kernel
  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
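Note (worked example, values chosen for illustration): the lcm-based jumps capture the phase pattern that input dilation creates. With idil[0] = 2 and str[0] = 3, f_out_jump_h = lcm(2, 3) / 3 = 2, so every second output row sees the same alignment of real (non-inserted-zero) input rows; with kdil[0] = 1, f_wgt_jump_h = lcm(2, 1) / 1 = 2 means only every second filter tap can land on a real row for a given phase. The grid's z dimension enumerates the f_out_jump_h * f_out_jump_w distinct phases, and base_h/base_w record, per phase, the first valid tap and how many taps remain.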
void winograd_conv_2D_gpu(
@ -299,6 +437,7 @@ void winograd_conv_2D_gpu(
  // Fill with zeros
  array zero_arr = array(0, in.dtype());
  copy_gpu(zero_arr, in_padded, CopyType::Scalar, s);
  copies_w.push_back(zero_arr);

  // Pick input slice from padded
  size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
@ -327,7 +466,8 @@ void winograd_conv_2D_gpu(
      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
      /* const int str[NDIM] = */ {1, 1},
      /* const int pad[NDIM] = */ {0, 0},
      /* const int dil[NDIM] = */ {1, 1},
      /* const int kdil[NDIM] = */ {1, 1},
      /* const int idil[NDIM] = */ {1, 1},
      /* const size_t in_strides[NDIM + 2] = */
      {in_padded.strides()[0],
       in_padded.strides()[1],
@ -337,6 +477,8 @@ void winograd_conv_2D_gpu(
      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
      /* const int groups = */ 1,
      /* const bool flip = */ false,
  };

  int O_c = conv_params.O;
@ -460,6 +602,8 @@ void conv_2D_gpu(
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    std::vector<array>& copies) {
  // Make conv params
  MLXConvParams<2> conv_params{
@ -471,37 +615,47 @@ void conv_2D_gpu(
      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
      /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
      /* const int pad[NDIM] = */ {padding[0], padding[1]},
      /* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
      /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
      /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
      /* const size_t in_strides[NDIM + 2] = */
      {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
      /* const size_t wt_strides[NDIM + 2] = */
      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
      /* const size_t out_strides[NDIM + 2] = */
      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
      /* const int groups = */ 1,
      /* const bool flip = */ flip,
  };

  bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
  bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
  bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;

  bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
  bool channels_large = (conv_params.C + conv_params.O) >= 512;
  bool channels_med = (conv_params.C + conv_params.O) >= 256;

  // Direct to winograd conv
  if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
      conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 &&
      conv_params.wS[1] == 3 && conv_params.str[0] == 1 &&
      conv_params.str[1] == 1 && conv_params.dil[0] == 1 &&
      conv_params.dil[1] == 1) {
    winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
  if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
      conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
      conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
      (channels_large || (channels_med && inp_large))) {
    return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
  }

  // Direct to implicit gemm conv
  else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) {
    implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
      (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
    return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }

  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
    return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to explicit gemm conv
  else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) {
    explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
  }

  // Direct to fallback conv
  else {
    slow_conv_2D_gpu(s, d, in, wt, out, conv_params);
    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
  }
}

@ -532,11 +686,31 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
  // 2D conv
  if (out.ndim() == 4) {
    conv_2D_gpu(
        s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies);
        s,
        d,
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_,
        copies);
  }
  // 1D conv
  else if (out.ndim() == 3) {
    conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_);
    conv_1D_gpu(
        s,
        d,
        in,
        wt,
        out,
        padding_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_);
  }
  // Throw error
  else {
@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Get kernel name
  std::ostringstream kname;
  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
  kname << "scatter" << type_to_name(out) << idx_type_name;

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  bool index_nd1_specialization = (idx_ndim == 1);

  // Bail from fast path (1d index specialization) if scatter dims aren't
  // the outermost dims and contiguous since update access won't be raster
  // order.
  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
    index_nd1_specialization &= (axes_[i] == i);
  }

  // Bail from fast path (1d index specialization) if any of the dims are
  // broadcasted, since we can't rely on linear indexing in that case.
  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
    index_nd1_specialization &= inputs[i].flags().row_contiguous;
  }

  if (index_nd1_specialization) {
    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
  } else {
    kname << "scatter" << type_to_name(out) << idx_type_name;
  }
  switch (reduce_type_) {
    case Scatter::None:
      kname << "_none";
@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {

  compute_encoder->setComputePipelineState(kernel);

  // Collect all idx shapes and strides into one place
  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  std::vector<int> idx_shapes;
  std::vector<size_t> idx_strides;

  for (int i = 0; i < nidx; ++i) {
    idx_shapes.insert(
        idx_shapes.end(),
        inputs[i + 1].shape().begin(),
        inputs[i + 1].shape().end());

    idx_strides.insert(
        idx_strides.end(),
        inputs[i + 1].strides().begin(),
        inputs[i + 1].strides().end());
  }

  // Set all the buffers
  set_array_buffer(compute_encoder, upd, 1);
  set_array_buffer(compute_encoder, out, 2);

  // Set update info
  size_t upd_ndim = upd.ndim();
  uint upd_ndim = upd.ndim();
  size_t upd_size = 1;
  for (int i = idx_ndim; i < upd.ndim(); ++i) {
    upd_size *= upd.shape(i);
  }
  if (upd_ndim == 0) {
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 3);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
  } else {
    compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);

  if (index_nd1_specialization) {
    bool upd_col_contiguous = upd.flags().col_contiguous;
    compute_encoder->setBytes(
        upd.strides().data(), upd_ndim * sizeof(size_t), 4);
  }
  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

  // Set output info
  size_t out_ndim = out.ndim();
  if (out_ndim == 0) {
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
    size_t stride_ = 0;
    compute_encoder->setBytes(&shape_, sizeof(int), 7);
    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
  } else {
    compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
        out.shape().data(), out.shape().size() * sizeof(int), 3);
    compute_encoder->setBytes(
        out.strides().data(), out_ndim * sizeof(size_t), 8);
  }
  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
    compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);

  // Set index info
  if (idx_ndim == 0) {
    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
    // error in the metal API.
    idx_shapes.push_back(0);
    idx_strides.push_back(0);
  }
  compute_encoder->setBytes(
      idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
  compute_encoder->setBytes(
      idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
  compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
    // Set index buffers
    for (int i = 1; i < nidx + 1; ++i) {
      set_array_buffer(compute_encoder, inputs[i], 20 + i);
    }

  // Set index buffers
  for (int i = 1; i < nidx + 1; ++i) {
    set_array_buffer(compute_encoder, inputs[i], 20 + i);
  }
    // Launch grid
    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);

  // Launch grid
  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
  compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    // Collect all idx shapes and strides into one place
    std::vector<int> idx_shapes;
    std::vector<size_t> idx_strides;

    for (int i = 0; i < nidx; ++i) {
      idx_shapes.insert(
          idx_shapes.end(),
          inputs[i + 1].shape().begin(),
          inputs[i + 1].shape().end());

      idx_strides.insert(
          idx_strides.end(),
          inputs[i + 1].strides().begin(),
          inputs[i + 1].strides().end());
    }

    if (upd_ndim == 0) {
      // Need placeholders so Metal doesn't complain
      int shape_ = 0;
      size_t stride_ = 0;
      compute_encoder->setBytes(&shape_, sizeof(int), 3);
      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
    } else {
      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
      compute_encoder->setBytes(
          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
    }
    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

    // Set output info
    size_t out_ndim = out.ndim();
    if (out_ndim == 0) {
      // Need placeholders so Metal doesn't complain
      int shape_ = 0;
      size_t stride_ = 0;
      compute_encoder->setBytes(&shape_, sizeof(int), 7);
      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
    } else {
      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
      compute_encoder->setBytes(
          out.strides().data(), out_ndim * sizeof(size_t), 8);
    }
    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

    // Set index info
    if (idx_ndim == 0) {
      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
      // error in the metal API.
      idx_shapes.push_back(0);
      idx_strides.push_back(0);
    }
    compute_encoder->setBytes(
        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
    compute_encoder->setBytes(
        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

    // Set index buffers
    for (int i = 1; i < nidx + 1; ++i) {
      set_array_buffer(compute_encoder, inputs[i], 20 + i);
    }

    // Launch grid
    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}

} // namespace mlx::core
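Note: the scatter_1d_index fast path selected above relies on those checks — scatter axes are the outermost, row-contiguous dims — so destinations reduce to pure linear arithmetic. A rough CPU sketch of that addressing with illustrative names (not the MLX kernel; only the Scatter::None assignment case is shown):

// Each update row y picks its destination from the 1-d index arrays, and
// x walks the contiguous trailing update elements.
#include <cstddef>
#include <vector>

void scatter_1d_ref(
    std::vector<float>& out,
    const std::vector<float>& updates,
    const std::vector<std::vector<int>>& idx, // one 1-d index array per axis
    const std::vector<std::size_t>& out_strides,
    std::size_t upd_size) { // elements per update row
  std::size_t n_rows = idx[0].size();
  for (std::size_t y = 0; y < n_rows; ++y) {
    std::size_t out_idx = 0;
    for (std::size_t i = 0; i < idx.size(); ++i)
      out_idx += std::size_t(idx[i][y]) * out_strides[i];
    for (std::size_t x = 0; x < upd_size; ++x)
      out[out_idx + x] = updates[y * upd_size + x]; // plain assignment
  }
}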
@ -3,11 +3,13 @@ set(
  ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
  ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
  ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)

@ -27,6 +29,7 @@ set(
  "scan"
  "softmax"
  "sort"
  "ternary"
  "unary"
  "gather"
  "scatter"
@ -48,11 +51,7 @@ endfunction(build_kernel_base)

function(build_kernel KERNEL)
  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
  set(HEADERS_PADDED ${HEADERS})
  if(${KERNEL} STREQUAL "conv")
    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
  endif()
  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
endfunction(build_kernel)

foreach(KERNEL ${KERNELS})
@ -11,8 +11,6 @@ template <typename U>
struct IndexValPair {
  uint32_t index;
  U val;

  IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
};

template <typename U>
@ -65,10 +63,10 @@ struct ArgMax {

template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
  return IndexValPair<U>(
  return IndexValPair<U>{
      simd_shuffle_down(data.index, delta),
      simd_shuffle_down(data.val, delta)
  );
  };
}
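Note: dropping the user-declared constructor (and switching call sites from IndexValPair<U>(...) to brace init) makes the struct an aggregate again. A minimal standalone illustration of why the braces are now required, assuming nothing beyond standard C++:

// With no user-declared constructor, the struct is an aggregate:
// brace initialization fills members in order, and the type stays
// trivially copyable, which register-level simd shuffles rely on.
struct Pair {
  unsigned index;
  float val;
};

Pair make_pair_example() {
  return Pair{3u, 1.5f}; // aggregate init; Pair(3u, 1.5f) would not compile
}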
@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
    const device size_t& ndim [[buffer(5)]],
    const device size_t& axis_stride [[buffer(6)]],
    const device size_t& axis_size [[buffer(7)]],
    threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
  auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
  auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);

  IndexValPair<T> best(0, Op::init);
  IndexValPair<T> best{0, Op::init};

  threadgroup IndexValPair<T> local_data[32];

  // Loop over the reduction axis in lsize*N_READS buckets
  for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
    const device size_t& ndim [[buffer(5)]], \
    const device size_t& axis_stride [[buffer(6)]], \
    const device size_t& axis_size [[buffer(7)]], \
    threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
    uint gid [[thread_position_in_grid]], \
    uint lid [[thread_position_in_threadgroup]], \
    uint lsize [[threads_per_threadgroup]], \
@ -2,16 +2,6 @@

#include "mlx/backend/metal/kernels/binary.h"

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_s2s(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[0]);
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
    device const T* a,

@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"

typedef half float16_t;
@ -1,481 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DInputBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BM;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  int offsets_n[n_rows];
  int offsets_oh[n_rows];
  int offsets_ow[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params.oS[0] * params.oS[1];

    for (int i = 0; i < n_rows; ++i) {
      int offset_nhw = tid.y * BM + bi + i * bstride;
      offsets_n[i] = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      offsets_oh[i] = hw / params.oS[1];
      offsets_ow[i] = hw % params.oS[1];
    }

    (void)lid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
      int n = offsets_n[i];
      int oh = offsets_oh[i];
      int ow = offsets_ow[i];

      int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
      int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];

      // Read from input if in bounds
      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
        const device T* curr_src = src + n * params.in_strides[0] +
            ih * params.in_strides[1] + iw * params.in_strides[2];

#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }
      }

      // Zero pad otherwise
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BN;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_.wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    (void)lid;
    (void)tid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src =
        src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = curr_src[i * src_ld + j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between elements
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets within threadgroup
  const int tm;
  const int tn;

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC Conv2DBlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  using loader_a_t =
      Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
  using loader_b_t =
      Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
  using mma_t = Conv2DBlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant MLXConvParams<2>& params [[buffer(3)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;
    const int K = params.wt_strides[0];
    const int N = params.O;

    B += c_col * K;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
    loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_gid, simd_lid);

    for (int k = 0; k < K; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Store results to device memory
    mma_op.store_result(C, N);
  }
};
@ -1,16 +1,102 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/conv_params.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"

#include "mlx/backend/metal/kernels/conv.h"
#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
/// Slow and naive kernels
/// Naive unfold with dilation
///////////////////////////////////////////////////////////////////////////////

template <typename T, int N>
[[kernel]] void naive_unfold_Nd(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant MLXConvParams<N>* params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]) {

  int filter_size = params->C;
  for(short i = 0; i < N; i++) filter_size *= params->wS[i];

  int out_pixels = 1;
  for(short i = 0; i < N; i++) out_pixels *= params->oS[i];

  // Set out
  out += gid.z * filter_size + gid.y * (params->C);

  // Coordinates in input
  int is[N] = {0};

  // gid.z: N oS (Batch and row in unfolded output)
  // gid.y: wS (Filter location to unfold input)
  // gid.x: C (channel)

  int n = (gid.z) / out_pixels;
  int oS = (gid.z) % out_pixels;
  int wS = gid.y;

  bool valid = n < params->N;

  // Unroll dimensions
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);

    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;

    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);

    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);

    is[i] = is_ / params->idil[i];

    oS /= params->oS[i];
    wS /= params->wS[i];
  }

  if(valid) {
    size_t in_offset = n * params->in_strides[0];

    for(int i = 0; i < N; ++i) {
      in_offset += is[i] * params->in_strides[i + 1];
    }

    out[gid.x] = in[in_offset + gid.x];
  } else {
    out[gid.x] = T(0);
  }

}

#define instantiate_naive_unfold_nd(name, itype, n) \
  template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
  [[kernel]] void naive_unfold_Nd( \
      const device itype* in [[buffer(0)]], \
      device itype* out [[buffer(1)]], \
      const constant MLXConvParams<n>* params [[buffer(2)]], \
      uint3 gid [[thread_position_in_grid]]);

#define instantiate_naive_unfold_nd_dims(name, itype) \
  instantiate_naive_unfold_nd(name, itype, 1) \
  instantiate_naive_unfold_nd(name, itype, 2) \
  instantiate_naive_unfold_nd(name, itype, 3)

instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
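Note (worked example): for a single axis with str = 1, pad = 1, kdil = 2, idil = 1 and flip = false, output position os_ = 0 with tap ws_ = 0 gives is_ = 0 * 1 - 1 + 0 * 2 = -1, which fails is_ >= 0, so the thread writes T(0); tap ws_ = 1 gives is_ = 1, which passes all three checks (0 <= 1 < is_max and 1 % 1 == 0) and reads input row 1.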
///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
@ -58,8 +144,8 @@ template <typename T,

  // Local in
  for(int m = 0; m < TM; m++) {
    int i = out_h[m] * params.str[0] - params.pad[0] + h * params.dil[0];
    int j = out_w[m] * params.str[1] - params.pad[1] + w * params.dil[1];
    int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
    int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];

    bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
    in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
@ -116,59 +202,6 @@ instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Implicit gemm kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  using gemm_kernel = Conv2DImplicitGEMMKernel<T, BM, BN, BK, WM, WN, /*transpose_a*/ false, /*transpose_b*/ true>;

  threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];

  gemm_kernel::run(
      in, wt, out,
      params, tgp_memory,
      tid, lid, simd_gid, simd_lid
  );

}

#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn>( \
      const device itype* in [[buffer(0)]], \
      const device itype* wt [[buffer(1)]], \
      device itype* out [[buffer(2)]], \
      const constant MLXConvParams<2>& params [[buffer(3)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \
  instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////
@ -1,19 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

template <int NDIM>
struct MLXConvParams {
  const int N; // Batch size
  const int C; // In channels
  const int O; // Out channels
  const int iS[NDIM]; // Input spatial dim
  const int wS[NDIM]; // Weight spatial dim
  const int oS[NDIM]; // Output spatial dim
  const int str[NDIM]; // Kernel strides
  const int pad[NDIM]; // Input padding
  const int dil[NDIM]; // Kernel dilation
  const size_t in_strides[NDIM + 2]; // In strides
  const size_t wt_strides[NDIM + 2]; // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
};
@@ -13,6 +13,58 @@ using namespace metal;
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    const constant bool& upd_col_contiguous [[buffer(6)]],
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid [[thread_position_in_grid]]) {

  Op op;

  uint out_idx = 0;
  for (int i = 0; i < NIDX; i++) {
    auto idx_val = offset_neg_idx(
        idx_buffers[i][gid.y], out_shape[i]);
    out_idx += idx_val * out_strides[i];
  }

  if (!upd_col_contiguous) {
    op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
  } else {
    op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
  }
}

#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
    const device T* updates [[buffer(1)]], \
    device mlx_atomic<T>* out [[buffer(2)]], \
    const constant int* out_shape [[buffer(3)]], \
    const constant size_t* out_strides [[buffer(4)]], \
    const constant size_t& upd_size [[buffer(5)]], \
    const constant bool& upd_col_contiguous [[buffer(6)]], \
    IDX_ARG(IdxT) \
    uint2 gid [[thread_position_in_grid]]) { \
  \
  const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
  \
  return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
      updates, \
      out, \
      out_shape, \
      out_strides, \
      upd_size, \
      upd_col_contiguous, \
      idx_buffers, \
      gid); \
  \
}

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
@@ -46,10 +98,14 @@ METAL_FUNC void scatter_impl(
    out_idx += idx_val * out_strides[ax];
  }

  auto out_offset = elem_to_loc(
      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
  if (upd_size > 1) {
    auto out_offset = elem_to_loc(
        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
    out_idx += out_offset;
  }

  auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
  op.atomic_update(out, updates[upd_idx], out_idx);
}

#define make_scatter_impl(IDX_ARG, IDX_ARR) \
@@ -90,9 +146,11 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
      axes, \
      idxs, \
      gid); \
}
}

#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_scatter(n) \
  make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
  make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)

make_scatter(0)
make_scatter(1)
@@ -129,8 +187,21 @@ template [[host_name("scatter" name "_" #nidx)]] \
    IDX_ARG(idx_t) \
    uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
    const device src_t* updates [[buffer(1)]], \
    device mlx_atomic<src_t>* out [[buffer(2)]], \
    const constant int* out_shape [[buffer(3)]], \
    const constant size_t* out_strides [[buffer(4)]], \
    const constant size_t& upd_size [[buffer(5)]], \
    const constant bool& upd_col_contiguous [[buffer(6)]], \
    IDX_ARG(idx_t) \
    uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
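For reference, the fast path added above resolves each row of `updates` to a single flat output offset from the index buffers, then applies the op atomically along the trailing axis. A minimal CPU sketch of the same addressing in plain C++ (illustrative only: `+=` stands in for a hypothetical Sum op plus the atomic update, and the row-contiguous `upd_col_contiguous == false` layout is assumed):

#include <array>
#include <cstddef>
#include <vector>

// CPU reference mirroring scatter_1d_index_impl's index arithmetic.
template <int NIDX>
void scatter_1d_index_ref(
    const std::vector<float>& updates, // [n_indices, upd_size]
    std::vector<float>& out,
    const std::array<int, NIDX>& out_shape, // leading indexed dims only
    const std::array<std::size_t, NIDX>& out_strides,
    std::size_t upd_size,
    const std::array<const int*, NIDX>& idx_buffers,
    std::size_t n_indices) {
  for (std::size_t y = 0; y < n_indices; ++y) { // gid.y in the kernel
    std::size_t out_idx = 0;
    for (int i = 0; i < NIDX; ++i) {
      int idx_val = idx_buffers[i][y];
      if (idx_val < 0) idx_val += out_shape[i]; // offset_neg_idx
      out_idx += std::size_t(idx_val) * out_strides[i];
    }
    for (std::size_t x = 0; x < upd_size; ++x) { // gid.x in the kernel
      out[out_idx + x] += updates[y * upd_size + x]; // Sum op, done atomically on GPU
    }
  }
}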
mlx/backend/metal/kernels/steel/conv/conv.h
@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"

using namespace metal;
using namespace mlx::steel;
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal
@@ -0,0 +1,189 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/mma.h"

#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          int N_CHANNELS = 0,
          bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  using namespace mlx::steel;

  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DInputBlockLoaderSmallChannels<
          T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,

      // Else go to general loader
      typename metal::conditional_t<
          // Check if filter size is small enough
          SMALL_FILTER,

          // Go to small filter specialization
          Conv2DInputBlockLoaderSmallFilter<
              T, BM, BN, BK, tgp_size, tgp_padding_a>,

          // Else go to large filter generalization
          Conv2DInputBlockLoaderLargeFilter<
              T, BM, BN, BK, tgp_size, tgp_padding_a>
      >
  >;

  // Weight loader
  using loader_b_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DWeightBlockLoaderSmallChannels<
          T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,

      // Else go to general loader
      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
  >;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;
  const int N = gemm_params->N;

  B += c_col * K;
  C += c_row * N + c_col;

  const int2 offsets_a(0, c_row);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
  loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations = gemm_params->gemm_k_iterations;
  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  short tgp_bm = min(BM, gemm_params->M - c_row);
  short tgp_bn = min(BN, gemm_params->N - c_col);
  mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}

#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
  template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
  [[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
      const device itype* A [[buffer(0)]], \
      const device itype* B [[buffer(1)]], \
      device itype* C [[buffer(2)]], \
      const constant MLXConvParams<2>* params [[buffer(3)]], \
      const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)

#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
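The kernel above never materializes an im2col matrix; the loaders stream input patches and weights as if they were ordinary GEMM operands. For intuition, here is a hedged host-side sketch in C++ of the GEMM shape this formulation implies for an NHWC convolution (the sizes and the ceil-div for the tile grid are assumed for illustration; this is not the actual MLX dispatch code):

#include <cstdio>

int main() {
  // Assumed example sizes: batch 2, 64 -> 128 channels, 32x32 input,
  // 3x3 filter, stride 1, padding 1.
  int Nb = 2, C = 64, O = 128;
  int iH = 32, iW = 32, wH = 3, wW = 3, str = 1, pad = 1;
  int oH = (iH + 2 * pad - wH) / str + 1;
  int oW = (iW + 2 * pad - wW) / str + 1;

  // Implicit GEMM view: out[M, N] = patches[M, K] * weights[N, K]^T
  int M = Nb * oH * oW; // one GEMM row per output pixel (tiled by BM)
  int N = O;            // one GEMM column per output channel (tiled by BN)
  int K = wH * wW * C;  // reduction over the filter window

  int BM = 32, BN = 32;
  int tiles_m = (M + BM - 1) / BM; // plausible ceil-div for the tile grid
  int tiles_n = (N + BN - 1) / BN;
  printf("M=%d N=%d K=%d tiles=(%d x %d)\n", M, N, K, tiles_m, tiles_n);
  return 0;
}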
@@ -0,0 +1,209 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/mma.h"

#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h"

using namespace metal;
using namespace mlx::steel;

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          typename AccumType = float,
          typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {

  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = Conv2DInputBlockLoaderGeneral<
      T, BM, BN, BK, tgp_size, tgp_padding_a>;

  // Weight loader
  using loader_b_t = Conv2DWeightBlockLoaderGeneral<
      T, BM, BN, BK, tgp_size, tgp_padding_b>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int tid_z = tid.z;

  const int base_oh = tid_z / jump_params->f_out_jump_w;
  const int base_ow = tid_z % jump_params->f_out_jump_w;

  const int base_wh = base_h[base_oh].weight_base;
  const int base_ww = base_w[base_ow].weight_base;

  const int base_wh_size = base_h[base_oh].weight_size;
  const int base_ww_size = base_w[base_ow].weight_size;

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;

  B += c_col * K;

  const int4 offsets_a(0, c_row, base_oh, base_ow);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
  loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  {
    // Adjust for simdgroup and thread location
    int offset_m = c_row + mma_op.sm + mma_op.tm;
    int offset_n = c_col + mma_op.sn + mma_op.tn;
    C += offset_n;

    if (offset_n >= gemm_params->N)
      return;

    short diff = gemm_params->N - offset_n;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < mma_t::TM; i++) {
      int cm = offset_m + i * mma_t::TM_stride;

      int n = cm / jump_params->adj_out_hw;
      int hw = cm % jump_params->adj_out_hw;
      int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
      int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
        int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];

        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < mma_t::TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
          int offset = offset_cm + (j * mma_t::TN_stride);

          // Apply epilogue and output C
          if (j * mma_t::TN_stride < diff) {
            C[offset] = Epilogue::apply(accum[0]);
          }

          if (j * mma_t::TN_stride + 1 < diff) {
            C[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }
}

#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
  template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
      const device itype* A [[buffer(0)]], \
      const device itype* B [[buffer(1)]], \
      device itype* C [[buffer(2)]], \
      const constant MLXConvParams<2>* params [[buffer(3)]], \
      const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
      const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
      const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
      const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)

#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)

instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
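The `_general` variant additionally handles input dilation (e.g. transposed convolution) by launching `f_out_jump_h * f_out_jump_w` threadgroup slices along `tid.z`; each slice owns the output pixels whose phase modulo the jump matches its `(base_oh, base_ow)`, so every pixel in a slice shares one set of valid filter taps (`base_h` / `base_w`). A small standalone C++ sketch of that partitioning (semantics inferred from the jump parameters; illustrative only):

#include <cstdio>

int main() {
  int f_out_jump_h = 2, f_out_jump_w = 2; // e.g. 2x2 input dilation
  for (int tid_z = 0; tid_z < f_out_jump_h * f_out_jump_w; ++tid_z) {
    int base_oh = tid_z / f_out_jump_w;
    int base_ow = tid_z % f_out_jump_w;
    // This slice covers output pixels
    //   oh = k * f_out_jump_h + base_oh, ow = l * f_out_jump_w + base_ow,
    // all of which see the same subset of valid filter taps.
    printf("tid.z=%d -> (oh, ow) = (%d + %d*k, %d + %d*l)\n",
           tid_z, base_oh, f_out_jump_h, base_ow, f_out_jump_w);
  }
  return 0;
}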
mlx/backend/metal/kernels/steel/conv/loader.h
@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"
mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h
@@ -0,0 +1,449 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderLargeFilter {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderLargeFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Adjust for flip
      if (params->flip) {
        ih += (params->wS[0] - 1) * params->kdil[0];
        iw += (params->wS[1] - 1) * params->kdil[1];
      }

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int ih = read_ih[i] + weight_h * params->kdil[0];
      int iw = read_iw[i] + weight_w * params->kdil[1];

      // Read from input if in bounds
      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
          (iw >= 0 && iw < params->iS[1])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallFilter {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  using mask_t = short;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  mask_t mask_h[n_rows];
  mask_t mask_w[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderSmallFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    int read_n[n_rows];
    int read_ih[n_rows];
    int read_iw[n_rows];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Adjust for flip
      if (params->flip) {
        ih += (params->wS[0] - 1) * params->kdil[0];
        iw += (params->wS[1] - 1) * params->kdil[1];
      }

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2] + bj;
    }

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      mask_h[i] = 0;
      mask_w[i] = 0;
    }

    for (short kh = 0; kh < params->wS[0]; kh++) {
      short flip_h = params->flip ? params->wS[0] - kh - 1 : kh;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int n = read_n[i];
        int ih = read_ih[i] + flip_h * params->kdil[0];

        bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];

        mask_h[i] |= (in_bounds << kh);
      }
    }

    for (short kw = 0; kw < params->wS[1]; kw++) {
      short flip_w = params->flip ? params->wS[1] - kw - 1 : kw;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; ++i) {
        int iw = read_iw[i] + flip_w * params->kdil[1];

        bool in_bounds = iw >= 0 && iw < params->iS[1];

        mask_w[i] |= (in_bounds << kw);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    mask_t h_mask = mask_t(1) << weight_h;
    mask_t w_mask = mask_t(1) << weight_w;

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Read from input if in bounds
      if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  int weight_hw;

  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_hw(0),
        read_n(offsets.y + bi),
        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((read_n + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = src[i * src_ld + j];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_hw < (params->wS[1] * params->wS[0])) {
      src += params->wt_strides[2];
      return;
    }

    weight_hw = 0;

    src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
  }
};

} // namespace steel
} // namespace mlx
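One detail worth calling out in `Conv2DInputBlockLoaderSmallFilter`: bounds along each spatial axis are precomputed into per-row bitmasks (`mask_h`, `mask_w`), so the hot `load_unsafe` loop replaces four range comparisons with two bit tests. Since `mask_t` is a `short`, the trick only has room for 16 taps per axis, which is presumably why this path is reserved for small filters. A standalone C++ illustration of the mask construction and lookup:

#include <cstdio>

int main() {
  int iS = 5, wS = 3, kdil = 1; // input extent, filter taps, kernel dilation
  int read_ih = -1;             // top-left input row for this output pixel

  // Precompute: bit kh of mask_h records whether tap kh lands in bounds.
  short mask_h = 0;
  for (short kh = 0; kh < wS; kh++) {
    int ih = read_ih + kh * kdil;
    bool in_bounds = ih >= 0 && ih < iS;
    mask_h |= (in_bounds << kh);
  }

  // Hot loop: a single bit test per tap instead of re-deriving bounds.
  for (short weight_h = 0; weight_h < wS; weight_h++) {
    bool valid = mask_h & (short(1) << weight_h);
    printf("tap %d: %s\n", weight_h, valid ? "load" : "zero-pad");
  }
  return 0;
}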
mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
@@ -0,0 +1,319 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <short n_channels_>
struct ChannelHelper {
  STEEL_CONST short n_channels = n_channels_;
  STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
  STEEL_CONST short excess = vec_size - n_channels_;
};

template <>
struct ChannelHelper<1> {
  STEEL_CONST short n_channels = 1;
  STEEL_CONST short vec_size = 1;
  STEEL_CONST short excess = 0;
};

template <>
struct ChannelHelper<2> {
  STEEL_CONST short n_channels = 2;
  STEEL_CONST short vec_size = 2;
  STEEL_CONST short excess = 0;
};

template <>
struct ChannelHelper<3> {
  STEEL_CONST short n_channels = 3;
  STEEL_CONST short vec_size = 4;
  STEEL_CONST short excess = 1;
};

template <>
struct ChannelHelper<4> {
  STEEL_CONST short n_channels = 4;
  STEEL_CONST short vec_size = 4;
  STEEL_CONST short excess = 0;
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallChannels {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  short weight_hw;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_hw(thread_idx % TCOLS) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    if (weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    int wh = (weight_hw / params->wS[1]);
    int ww = (weight_hw % params->wS[1]);

    int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
    int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;

    int weight_h = flip_h * params->kdil[0];
    int weight_w = flip_w * params->kdil[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int ih = read_ih[i] + weight_h;
      int iw = read_iw[i] + weight_w;

      // Read from input if in bounds
      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
          (iw >= 0 && iw < params->iS[1])) {
        const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
            weight_w * params->in_strides[2];

        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < n_channels; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }

        STEEL_PRAGMA_UNROLL
        for (short j = n_channels; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_hw += TCOLS;
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderSmallChannels {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  int weight_hw;

  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld),
        params(params_),
        weight_hw(thread_idx % TCOLS),
        read_n(offsets.y + bi),
        do_read(read_n + BN <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    if (bi >= BROWS || bj >= BCOLS)
      return;

    if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }

      return;
    }

    const device T* curr_src = src + weight_hw * params->wt_strides[2];

    if (BN != 8 || do_read) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < n_channels; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }

        STEEL_PRAGMA_UNROLL
        for (short j = n_channels; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
    } else {
      for (short i = 0; i < BROWS; i += TROWS) {
        if (((read_n + i) < params->O)) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < n_channels; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }

          STEEL_PRAGMA_UNROLL
          for (short j = n_channels; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_hw += TCOLS;
  }
};

} // namespace steel
} // namespace mlx
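The `ChannelHelper` specializations at the top of this file encode one rule: round the channel count up to a width the hardware can load as a vector, and zero-fill the `excess` lanes (the loaders above write `T(0)` into `j = n_channels .. vec_size`). A plain C++ mirror of that rule, matching the specialized constants (illustrative):

#include <cstdio>

// Matches ChannelHelper: 1 -> 1, 2 -> 2, 3 -> 4 (one excess lane), 4 -> 4.
constexpr short vec_size_for(short n_channels) {
  return (n_channels == 1 || n_channels == 2)
      ? n_channels                   // already a vector-friendly width
      : (n_channels <= 4 ? 4 : 8);   // pad 3 up to 4, 5..8 up to 8
}

int main() {
  for (short c = 1; c <= 4; ++c)
    printf("C=%d -> vec_size=%d excess=%d\n", c, vec_size_for(c),
           vec_size_for(c) - c);
  return 0;
}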
mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
@@ -0,0 +1,288 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/conv/params.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  const short base_wh;
  const short base_ww;

  short weight_h;
  short weight_w;

  const device T* src[n_rows];

  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int4 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / jump_params->adj_out_hw;
      int hw = offset_nhw % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;

      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // Read from input if in bounds
      src[i] = src_ + n * params->in_strides[0] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];

      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;

      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
      int iw_dil = read_iw[i] + w_flip * params->kdil[1];

      int ih = ih_dil / params->idil[0];
      int iw = iw_dil / params->idil[1];

      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];

      // Read from input if in bounds
      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = (src[i])[offset + j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    weight_h = base_wh;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += BK;
    }
  }
};

template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
  // Destination dimensions
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  const short base_wh;
  const short base_ww;

  short weight_h;
  short weight_w;

  const int start_row;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_),
        start_row(offsets.y + bi) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    if ((start_row + BN <= params->O)) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BN; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
        }
      }
    } else {
      for (short i = 0; i < BN; i += TROWS) {
        if ((start_row + i) < params->O) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            dst[i * dst_ld + j] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w < params->wS[1]) {
      return;
    }

    weight_w = base_ww;

    weight_h += jump_params->f_wgt_jump_h;
    if (weight_h < params->wS[0]) {
      return;
    }

    weight_h = base_wh;

    src += BK;
  }
};

} // namespace steel
} // namespace mlx
mlx/backend/metal/kernels/steel/conv/params.h
@@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.

#pragma once

template <int NDIM>
struct MLXConvParams {
  const int N; // Batch size
  const int C; // In channels
  const int O; // Out channels
  const int iS[NDIM]; // Input spatial dim
  const int wS[NDIM]; // Weight spatial dim
  const int oS[NDIM]; // Output spatial dim
  const int str[NDIM]; // Kernel strides
  const int pad[NDIM]; // Input padding
  const int kdil[NDIM]; // Kernel dilation
  const int idil[NDIM]; // Input dilation
  const size_t in_strides[NDIM + 2]; // In strides
  const size_t wt_strides[NDIM + 2]; // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
  const int groups; // Input channel groups
  const bool flip;
};

namespace mlx {
namespace steel {

struct ImplicitGemmConv2DParams {
  const int M;
  const int N;
  const int K;

  const int gemm_k_iterations;

  const int inp_jump_w;
  const int inp_jump_h;
  const int inp_jump_c;

  const int tiles_n;
  const int tiles_m;
  const int swizzle_log;
};

struct Conv2DGeneralJumpParams {
  const int f_wgt_jump_h;
  const int f_wgt_jump_w;

  const int f_out_jump_h;
  const int f_out_jump_w;

  const int adj_out_h;
  const int adj_out_w;
  const int adj_out_hw;
  const int adj_implicit_m;
};

struct Conv2DGeneralBaseInfo {
  int weight_base;
  int weight_size;
};

} // namespace steel
} // namespace mlx
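To make the parameter block concrete, here is a hedged sketch of how a host might fill `MLXConvParams<2>` for a stride-1, unpadded NHWC convolution. This is not the actual MLX host code (the real struct has `const` members and is populated inside the C++ backend); the mirror struct below drops `const` so aggregate initialization works:

#include <cstddef>

// Non-const mirror of MLXConvParams<2>; field order matches the header above.
template <int NDIM>
struct ConvParamsMirror {
  int N, C, O;
  int iS[NDIM], wS[NDIM], oS[NDIM];
  int str[NDIM], pad[NDIM], kdil[NDIM], idil[NDIM];
  std::size_t in_strides[NDIM + 2], wt_strides[NDIM + 2], out_strides[NDIM + 2];
  int groups;
  bool flip;
};

int main() {
  int N = 1, C = 16, O = 32, iH = 8, iW = 8, wH = 3, wW = 3;
  int oH = iH - wH + 1, oW = iW - wW + 1; // stride 1, no padding
  ConvParamsMirror<2> p{
      N, C, O,
      {iH, iW}, {wH, wW}, {oH, oW},
      {1, 1}, {0, 0}, {1, 1}, {1, 1},
      // Contiguous NHWC input, OHWC weights, NHWC output strides.
      {std::size_t(iH) * iW * C, std::size_t(iW) * C, std::size_t(C), 1},
      {std::size_t(wH) * wW * C, std::size_t(wW) * C, std::size_t(C), 1},
      {std::size_t(oH) * oW * O, std::size_t(oW) * O, std::size_t(O), 1},
      /*groups=*/1,
      /*flip=*/false};
  (void)p;
  return 0;
}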
@@ -4,6 +4,7 @@

#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

@@ -2,9 +2,15 @@

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
@@ -167,6 +173,9 @@ struct BlockMMA {
    C += (sm + tm) * ldc + (tn + sn);
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
@@ -236,6 +245,9 @@ struct BlockMMA {
    D += (sm + tm) * ldd + tn + sn;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
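Both `mma.h` hunks add the same guard to the two store paths: once the destination pointer and tile extents are offset by the simdgroup/thread position, a thread whose remaining extent is non-positive has nothing to store and can exit before the unrolled loops run. A scalar C++ illustration of the check:

#include <cstdio>

int main() {
  short tile_x = 32, tile_y = 32; // dst_tile_dims before offsetting
  short sn_tn = 40, sm_tm = 8;    // this thread's column / row offset
  tile_x -= sn_tn;
  tile_y -= sm_tm;
  // Mirrors: if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return;
  if (tile_x <= 0 || tile_y <= 0)
    printf("nothing to store, early return\n");
  return 0;
}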
@@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/params.h"
@@ -3,7 +3,6 @@
#pragma once

#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/host.h"

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
mlx/backend/metal/kernels/ternary.h
@@ -0,0 +1,10 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};
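`Select` is an ordinary functor, so its behavior can be checked in plain C++ outside Metal; it is the elementwise op that the `ternary_op_*` kernels below instantiate for `where`-style selection. A minimal usage sketch:

#include <cstdio>

struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};

int main() {
  Select op;
  // Picks x when the condition holds, y otherwise.
  printf("%d %d\n", op(true, 1, 2), op(false, 1, 2)); // prints: 1 2
  return 0;
}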
201
mlx/backend/metal/kernels/ternary.metal
Normal file
201
mlx/backend/metal/kernels/ternary.metal
Normal file
@ -0,0 +1,201 @@
// Copyright © 2023 Apple Inc.

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/ternary.h"

template <typename T, typename Op>
[[kernel]] void ternary_op_v(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    uint index [[thread_position_in_grid]]) {
  d[index] = Op()(a[index], b[index], c[index]);
}

template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd1(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_strides);
  auto b_idx = elem_to_loc_1(index, b_strides);
  auto c_idx = elem_to_loc_1(index, c_strides);
  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  auto c_idx = elem_to_loc_2(index, c_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd3(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  auto c_idx = elem_to_loc_3(index, c_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op, int DIM>
[[kernel]] void ternary_op_g_nd(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    constant const size_t c_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}

template <typename T, typename Op>
[[kernel]] void ternary_op_g(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}

#define instantiate_ternary_v(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void ternary_op_v<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      uint index [[thread_position_in_grid]]); \

#define instantiate_ternary_g(name, type, op) \
  template [[host_name(name)]] \
  [[kernel]] void ternary_op_g<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const size_t* c_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \

#define instantiate_ternary_g_dim(name, type, op, dims) \
  template [[host_name(name "_" #dims)]] \
  [[kernel]] void ternary_op_g_nd<type, op, dims>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const int shape[dims], \
      constant const size_t a_strides[dims], \
      constant const size_t b_strides[dims], \
      constant const size_t c_strides[dims], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \

#define instantiate_ternary_g_nd(name, type, op) \
  template [[host_name(name "_1")]] \
  [[kernel]] void ternary_op_g_nd1<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const size_t& a_strides, \
      constant const size_t& b_strides, \
      constant const size_t& c_strides, \
      uint index [[thread_position_in_grid]]); \
  template [[host_name(name "_2")]] \
  [[kernel]] void ternary_op_g_nd2<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const size_t a_strides[2], \
      constant const size_t b_strides[2], \
      constant const size_t c_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
  template [[host_name(name "_3")]] \
  [[kernel]] void ternary_op_g_nd3<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
      device type* d, \
      constant const size_t a_strides[3], \
      constant const size_t b_strides[3], \
      constant const size_t c_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
  instantiate_ternary_g_dim(name, type, op, 4) \
  instantiate_ternary_g_dim(name, type, op, 5) \

#define instantiate_ternary_all(name, tname, type, op) \
  instantiate_ternary_v("v" #name #tname, type, op) \
  instantiate_ternary_g("g" #name #tname, type, op) \
  instantiate_ternary_g_nd("g" #name #tname, type, op) \

#define instantiate_ternary_types(name, op) \
  instantiate_ternary_all(name, bool_, bool, op) \
  instantiate_ternary_all(name, uint8, uint8_t, op) \
  instantiate_ternary_all(name, uint16, uint16_t, op) \
  instantiate_ternary_all(name, uint32, uint32_t, op) \
  instantiate_ternary_all(name, uint64, uint64_t, op) \
  instantiate_ternary_all(name, int8, int8_t, op) \
  instantiate_ternary_all(name, int16, int16_t, op) \
  instantiate_ternary_all(name, int32, int32_t, op) \
  instantiate_ternary_all(name, int64, int64_t, op) \
  instantiate_ternary_all(name, float16, half, op) \
  instantiate_ternary_all(name, float32, float, op) \
  instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
  instantiate_ternary_all(name, complex64, complex64_t, op) \

instantiate_ternary_types(select, Select)
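The `select` kernels instantiated above are what back MLX's ternary `where` operation on the GPU. A rough usage sketch via the standard `mlx.core` Python API (the comment about kernel variants is inferred from the dispatch logic, not guaranteed):

```python
import mlx.core as mx

cond = mx.array([True, False, True])
x = mx.array([1.0, 2.0, 3.0])
y = mx.array([-1.0, -2.0, -3.0])

# Contiguous inputs dispatch the "v" (vectorized) variant; strided or
# broadcast inputs fall back to the "g" (general) kernels.
print(mx.where(cond, x, y))  # array([1, -2, 3], dtype=float32)
```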
@ -9,6 +9,10 @@
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/utils.h"

namespace {
constant float inf = metal::numeric_limits<float>::infinity();
}

struct Abs {
  template <typename T>
  T operator()(T x) {
@ -91,6 +91,30 @@ inline size_t elem_to_loc(
  return loc;
}

template <int NDIM>
inline uint3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM],
    constant const size_t c_strides[NDIM]) {
  uint3 loc = {
      static_cast<uint>(
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    loc.z += l * c_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

template <int NDIM>
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
@ -150,6 +174,30 @@ inline size_t elem_to_loc(
  return loc;
}

inline uint3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    int ndim) {
  uint3 loc = {
      static_cast<uint>(
          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
      static_cast<uint>(
          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
      static_cast<uint>(
          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    loc.z += l * c_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
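Both `elem_to_loc_3_nd` overloads share the same stride arithmetic: the x and y components of the thread index address the two innermost axes directly, and z is peeled off one remaining axis at a time. For intuition only, here is the per-array computation restated as a hypothetical pure-Python helper (not part of the codebase):

```python
def elem_to_loc(elem, shape, strides):
    # elem is the (x, y, z) thread position.
    x, y, z = elem
    loc = x * strides[-1] + y * strides[-2]
    for d in range(len(shape) - 3, -1, -1):
        loc += (z % shape[d]) * strides[d]
        z //= shape[d]
    return loc
```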
@ -8,7 +8,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/host.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cstdlib>
#include <future>
@ -10,6 +10,10 @@

namespace mlx::core::metal {

bool is_available() {
  return true;
}

int max_ops_per_buffer() {
  auto get_val = []() {
    if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#pragma once

@ -11,16 +11,52 @@

namespace mlx::core::metal {

constexpr bool is_available() {
#ifdef _METAL_
  return true;
#else
  return false;
#endif
}
bool is_available();

bool cache_enabled(void);
void set_cache_enabled(bool enabled);
/* Get the actively used memory in bytes.
 *
 * Note, this will not always match memory use reported by the system because
 * it does not include cached memory buffers.
 * */
size_t get_active_memory();

/* Get the peak amount of used memory in bytes.
 *
 * The maximum memory used is recorded from the beginning of the program
 * execution.
 * */
size_t get_peak_memory();

/* Get the cache size in bytes.
 *
 * The cache includes memory not currently used that has not been returned
 * to the system allocator.
 * */
size_t get_cache_memory();

/* Set the memory limit.
 * Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
 * there are no more scheduled tasks an error will be raised if relaxed
 * is false or memory will be allocated (including the potential for
 * swap) if relaxed is true.
 *
 * The memory limit defaults to 1.5 times the maximum recommended working set
 * size reported by the device.
 *
 * Returns the previous memory limit.
 * */
size_t set_memory_limit(size_t limit, bool relaxed = true);

/* Set the free cache limit.
 * If using more than the given limit, free memory will be reclaimed
 * from the cache on the next allocation. To disable the cache,
 * set the limit to 0.
 *
 * The cache limit defaults to the memory limit.
 *
 * Returns the previous cache limit.
 * */
size_t set_cache_limit(size_t limit);

void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
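These declarations back the new memory-introspection API. A hedged sketch of driving them from Python, assuming the bindings are exposed under `mx.metal` as in this release line:

```python
import mlx.core as mx

if mx.metal.is_available():
    # Cap the Metal heap at 4 GiB; returns the previous limit.
    old_limit = mx.metal.set_memory_limit(4 << 30)
    mx.eval(mx.ones((4096, 4096)))
    print(mx.metal.get_active_memory())  # bytes held by live arrays
    print(mx.metal.get_peak_memory())    # high-water mark since startup
    print(mx.metal.get_cache_memory())   # cached-but-unused buffer bytes
    mx.metal.set_cache_limit(0)          # a limit of 0 disables the cache
```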
@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>

#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
@ -43,24 +43,25 @@ void binary_op(

  std::ostringstream kname;
  switch (bopt) {
    case ScalarScalar:
    case BinaryOpType::ScalarScalar:
      kname << "ss";
      break;
    case ScalarVector:
    case BinaryOpType::ScalarVector:
      kname << "sv";
      break;
    case VectorScalar:
    case BinaryOpType::VectorScalar:
      kname << "vs";
      break;
    case VectorVector:
    case BinaryOpType::VectorVector:
      kname << "vv";
      break;
    case General:
    case BinaryOpType::General:
      kname << "g";
      break;
  }
  kname << op << type_to_name(a);
  if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
  if (bopt == BinaryOpType::General &&
      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }

@ -80,7 +81,7 @@ void binary_op(
  set_array_buffer(compute_encoder, outputs[0], 2);
  set_array_buffer(compute_encoder, outputs[1], 3);

  if (bopt == General) {
  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
@ -141,24 +142,25 @@ void binary_op(

  std::ostringstream kname;
  switch (bopt) {
    case ScalarScalar:
    case BinaryOpType::ScalarScalar:
      kname << "ss";
      break;
    case ScalarVector:
    case BinaryOpType::ScalarVector:
      kname << "sv";
      break;
    case VectorScalar:
    case BinaryOpType::VectorScalar:
      kname << "vs";
      break;
    case VectorVector:
    case BinaryOpType::VectorVector:
      kname << "vv";
      break;
    case General:
    case BinaryOpType::General:
      kname << "g";
      break;
  }
  kname << op << type_to_name(a);
  if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
  if (bopt == BinaryOpType::General &&
      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << shape.size();
  }

@ -173,7 +175,7 @@ void binary_op(
  set_array_buffer(compute_encoder, donate_b ? out : b, 1);
  set_array_buffer(compute_encoder, out, 2);

  if (bopt == General) {
  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
@ -202,7 +204,94 @@ void binary_op(
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads = bopt == General ? out.size() : out.data_size();
    size_t nthreads =
        bopt == BinaryOpType::General ? out.size() : out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}

void ternary_op(
    const std::vector<array>& inputs,
    array& out,
    const std::string op) {
  assert(inputs.size() == 3);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& c = inputs[2];
  TernaryOpType topt = get_ternary_op_type(a, b, c);
  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);

  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
  auto& strides_a = strides[0];
  auto& strides_b = strides[1];
  auto& strides_c = strides[2];
  auto& strides_out = strides[3];

  std::ostringstream kname;
  if (topt == TernaryOpType::General) {
    kname << "g";
    kname << op << type_to_name(b);
    if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
      kname << "_" << shape.size();
    }
  } else {
    kname << "v";
    kname << op << type_to_name(b);
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);
  auto kernel = d.get_kernel(kname.str());
  auto compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  set_array_buffer(compute_encoder, a, 0);
  set_array_buffer(compute_encoder, b, 1);
  set_array_buffer(compute_encoder, c, 2);
  set_array_buffer(compute_encoder, out, 3);

  if (topt == TernaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);

      if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
        compute_encoder->setBytes(&ndim, sizeof(int), 8);
      }
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
    }

    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
    }
    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
@ -430,8 +519,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
    compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
    compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
    compute_encoder->setThreadgroupMemoryLength(
        simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
}
@ -621,6 +708,10 @@ void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "mul");
}

void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
  ternary_op(inputs, out, "select");
}

void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "neg");
}
@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <stdexcept>

@ -6,6 +6,10 @@

namespace mlx::core::metal {

bool is_available() {
  return false;
}

void new_stream(Stream) {}
std::shared_ptr<void> new_scoped_memory_pool() {
  return nullptr;
@ -19,10 +23,21 @@ std::function<void()> make_task(
      "[metal::make_task] Cannot make GPU task without metal backend");
}

// No cache for CPU only
bool cache_enabled(void) {
  return false;
// No-ops when Metal is not available.
size_t get_active_memory() {
  return 0;
}
size_t get_peak_memory() {
  return 0;
}
size_t get_cache_memory() {
  return 0;
}
size_t set_memory_limit(size_t, bool) {
  return 0;
}
size_t set_cache_limit(size_t) {
  return 0;
}
void set_cache_enabled(bool) {}

} // namespace mlx::core::metal

@ -80,6 +80,7 @@ NO_GPU(Reshape)
NO_GPU(Round)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Select)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) {
      typeid(p) == typeid(Subtract));
}

bool is_ternary(const Primitive& p) {
  return typeid(p) == typeid(Select);
}

bool is_broadcast(const Primitive& p) {
  return typeid(p) == typeid(Broadcast);
}
@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) {
}

bool is_fusable(const Primitive& p) {
  return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
  return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
      is_noop(p);
}

bool allows_shapeless(const Primitive& p) {
  return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
      is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
      typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
      typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition);
      typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
      typeid(p) == typeid(Select);
}

Compiled::Compiled(
@ -46,10 +46,6 @@ struct Dtype {
  };
};

inline bool is_available(const Dtype& dtype) {
  return true;
}

static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};

static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
311 mlx/ops.cpp
@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <climits>
#include <cmath>
#include <numeric>
#include <set>
@ -73,10 +74,24 @@ array arange(
  if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
    throw std::invalid_argument("[arange] Cannot compute length.");
  }
  double real_size = std::ceil((stop - start) / step);
  if (std::isnan(real_size)) {

  if (std::isinf(start) || std::isinf(stop)) {
    throw std::invalid_argument("[arange] Cannot compute length.");
  }

  // Check if start and stop specify a valid range because if not, we have to
  // return an empty array
  if (std::isinf(step) &&
      (step > 0 && start < stop || step < 0 && start > stop)) {
    return array({start}, dtype);
  }

  double real_size = std::ceil((stop - start) / step);

  if (real_size > INT_MAX) {
    throw std::invalid_argument("[arange] Maximum size exceeded.");
  }

  int size = std::max(static_cast<int>(real_size), 0);
  return array(
      {size},
@ -1149,13 +1164,20 @@ array isneginf(const array& a, StreamOrDevice s /* = {} */) {
}

array where(
    const array& condition,
    const array& x,
    const array& y,
    const array& a,
    const array& b,
    const array& c,
    StreamOrDevice s /* = {} */) {
  // TODO, fix this to handle the NaN case when x has infs
  auto mask = astype(condition, bool_, s);
  return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s);
  auto condition = astype(a, bool_, s);
  Dtype out_dtype = promote_types(b.dtype(), c.dtype());
  auto inputs = broadcast_arrays(
      {condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s);

  return array(
      inputs[0].shape(),
      out_dtype,
      std::make_unique<Select>(to_stream(s)),
      inputs);
}

array allclose(
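The rewritten `where` now broadcasts its inputs and promotes the output type before building a `Select` primitive, instead of synthesizing the result from multiplies and adds. A small example of the resulting semantics (standard `mlx.core` API):

```python
import mlx.core as mx

cond = mx.array([[True], [False]])       # shape (2, 1)
b = mx.array([1, 2, 3], dtype=mx.int32)  # shape (3,)
c = mx.array(0.5)                        # float32 scalar
out = mx.where(cond, b, c)
# Broadcast to (2, 3) with promote_types(int32, float32) -> float32.
print(out.shape, out.dtype)
```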
@ -1678,7 +1700,7 @@ array argpartition(
  int kth_ = kth < 0 ? kth + a.shape(axis) : kth;
  if (kth_ < 0 || kth_ >= a.shape(axis_)) {
    std::ostringstream msg;
    msg << "[argpartition] Received invalid kth " << kth << "along axis "
    msg << "[argpartition] Received invalid kth " << kth << " along axis "
        << axis << " for array with shape: " << a.shape();
    throw std::invalid_argument(msg.str());
  }
@ -1699,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
  // Check for valid axis
  int axis_ = axis < 0 ? axis + a.ndim() : axis;
  int kth_ = k < 0 ? k + a.shape(axis) : k;
  if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
    std::ostringstream msg;
    msg << "[topk] Received invalid axis " << axis << " for array with "
        << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (kth_ < 0 || kth_ >= a.shape(axis_)) {
  if (k < 0 || k > a.shape(axis_)) {
    std::ostringstream msg;
    msg << "[topk] Received invalid k " << k << "along axis " << axis
    msg << "[topk] Received invalid k=" << k << " along axis " << axis
        << " for array with shape: " << a.shape();
    throw std::invalid_argument(msg.str());
  }

  array a_partitioned = partition(a, kth_, axis_, s);
  // Return early if the whole input was requested.
  if (k == a.shape(axis_)) {
    return a;
  }

  array a_partitioned = partition(a, -k, axis_, s);
  std::vector<int> slice_starts(a.ndim(), 0);
  std::vector<int> slice_ends = a.shape();
  slice_starts[axis_] = kth_;
  slice_starts[axis_] = a.shape(axis_) - k;
  return slice(a_partitioned, slice_starts, slice_ends, s);
}
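With this change `topk` validates `k` directly, partitions on `-k`, and slices the trailing `k` entries, so it returns the `k` largest values rather than an off-by-`kth_` slice. For example:

```python
import mlx.core as mx

a = mx.array([5, 1, 9, 3, 7])
# The three largest values; as with partition, the order within the
# result is not guaranteed.
print(mx.topk(a, 3))
```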
@ -1733,7 +1759,11 @@ array logsumexp(
    StreamOrDevice s /* = {}*/) {
  auto maxval = stop_gradient(max(a, axes, true, s));
  auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
  return add(out, reshape(maxval, out.shape(), s), s);
  out = add(out, reshape(maxval, out.shape(), s), s);
  if (!keepdims) {
    maxval = squeeze(maxval, axes, s);
  }
  return where(isinf(maxval, s), maxval, out, s);
}

array logsumexp(
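The extra `where` on `maxval` guards the all-`-inf` reduction, where `a - maxval` would otherwise produce `nan`. A quick check of the intended behavior:

```python
import mlx.core as mx

a = mx.array([-float("inf"), -float("inf")])
print(mx.logsumexp(a))  # -inf rather than nan
```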
@ -2670,33 +2700,78 @@ array cummin(
namespace {

// Conv helpers
inline int conv_out_axis_size(
    int in_dim,
    int wt_dim,
    int stride,
    int padding,
    int dilation) {
  int ker = dilation * (wt_dim - 1);
  return ((in_dim + 2 * padding - ker - 1) / stride) + 1;
inline int conv_out_axis_size(int in_dim, int wt_dim, int stride, int padding) {
  return ((in_dim + padding - wt_dim) / stride) + 1;
}

// Conv helpers
inline int dilate_size(int dim, int dil) {
  return 1 + dil * (dim - 1);
}

inline std::vector<int> conv_out_shape(
    const std::vector<int>& in_shape,
    const std::vector<int>& wt_shape,
    const std::vector<int>& strides,
    const std::vector<int>& pads,
    const std::vector<int>& dilation) {
    const std::vector<int>& pads_lo,
    const std::vector<int>& pads_hi,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  int N = in_shape[0];
  int O = wt_shape[0];
  std::vector<int> out_shape(in_shape.size());
  int i = 0;
  out_shape[i++] = N;

  int spatial_dims = in_shape.size() - 2;

  if (strides.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid strides " << strides << " for " << spatial_dims
        << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (pads_lo.size() != spatial_dims || pads_hi.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid padding " << pads_lo << " | " << pads_hi << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (kernel_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid kernel dilation " << kernel_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  if (input_dilation.size() != spatial_dims) {
    std::ostringstream msg;
    msg << "[conv] Invalid input dilation " << input_dilation << " for "
        << spatial_dims << "D convolution.";
    throw std::invalid_argument(msg.str());
  }

  for (; i < in_shape.size() - 1; i++) {
    if (pads[i - 1] < 0) {
    if (kernel_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Kernel dilation sizes must be positive."
          << " Got kernel dilation " << kernel_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (input_dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Input dilation sizes must be positive."
          << " Got input dilation " << input_dilation << ".";
      throw std::invalid_argument(msg.str());
    }

    if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
      std::ostringstream msg;
      msg << "[conv] Padding sizes must be non-negative."
          << " Got padding " << pads << ".";
          << " Got padding " << pads_lo << " | " << pads_hi << ".";
      throw std::invalid_argument(msg.str());
    }
@ -2707,22 +2782,19 @@ inline std::vector<int> conv_out_shape(
      throw std::invalid_argument(msg.str());
    }

    if (dilation[i - 1] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Dilation sizes must be positive."
          << " Got dilation " << dilation << ".";
      throw std::invalid_argument(msg.str());
    }
    int kd = dilate_size(wt_shape[i], kernel_dilation[i - 1]);
    int id = dilate_size(in_shape[i], input_dilation[i - 1]);

    out_shape[i] = conv_out_axis_size(
        in_shape[i], wt_shape[i], strides[i - 1], pads[i - 1], dilation[i - 1]);
        id, kd, strides[i - 1], pads_lo[i - 1] + pads_hi[i - 1]);

    if (out_shape[i] <= 0) {
      std::ostringstream msg;
      msg << "[conv] Spatial dimensions of input after padding "
          << " cannot be smaller than weight spatial dimensions."
          << " Got input with shape " << in_shape << " and padding " << pads
          << " for weight of shape " << wt_shape << ".";
          << " Got error at axis " << i << " for input with shape " << in_shape
          << ", padding low " << pads_lo << ", padding high " << pads_hi
          << ", and weight of shape " << wt_shape << ".";
      throw std::invalid_argument(msg.str());
    }
  }
@ -2777,43 +2849,16 @@ array conv1d(
    int dilation /* = 1 */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  // Run checks
  if (groups != 1) {
    throw std::invalid_argument("[conv1d] Cannot handle groups != 1 yet");
  }
  if (dilation != 1) {
    throw std::invalid_argument("[conv1d] Cannot handle dilation != 1 yet");
  }

  // Run checks
  run_conv_checks(in_, wt_, 1);

  auto in = in_;
  auto wt = wt_;

  // Type promotion
  auto out_type = promote_types(in.dtype(), wt.dtype());
  in = astype(in, out_type, s);
  wt = astype(wt, out_type, s);

  std::vector<int> strides_vec = {stride};
  std::vector<int> padding_vec = {padding};
  std::vector<int> dilation_vec = {dilation};

  // Get output shapes
  std::vector<int> out_shape = conv_out_shape(
      in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec);

  return array(
      out_shape,
      in.dtype(),
      std::make_unique<Convolution>(
          to_stream(s),
          padding_vec,
          strides_vec,
          dilation_vec,
          std::vector<int>(1, 1)),
      {in, wt});
  return conv_general(
      /* const array& input = */ in_,
      /* const array& weight = */ wt_,
      /* std::vector<int> stride = */ {stride},
      /* std::vector<int> padding = */ {padding},
      /* std::vector<int> kernel_dilation = */ {dilation},
      /* std::vector<int> input_dilation = */ {1},
      /* int groups = */ groups,
      /* bool flip = */ false,
      s);
}

/** 2D convolution with a filter */
@ -2825,42 +2870,98 @@ array conv2d(
    const std::pair<int, int>& dilation /* = {1, 1} */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  return conv_general(
      /* const array& input = */ in_,
      /* const array& weight = */ wt_,
      /* std::vector<int> stride = */ {stride.first, stride.second},
      /* std::vector<int> padding = */ {padding.first, padding.second},
      /* std::vector<int> kernel_dilation = */
      {dilation.first, dilation.second},
      /* std::vector<int> input_dilation = */ {1, 1},
      /* int groups = */ groups,
      /* bool flip = */ false,
      s);
}

/** General convolution with a filter */
array conv_general(
    array in,
    array wt,
    std::vector<int> stride /* = {} */,
    std::vector<int> padding_lo /* = {} */,
    std::vector<int> padding_hi /* = {} */,
    std::vector<int> kernel_dilation /* = {} */,
    std::vector<int> input_dilation /* = {} */,
    int groups /* = 1 */,
    bool flip /* = false */,
    StreamOrDevice s /* = {} */) {
  // Run checks
  if (groups != 1) {
    throw std::invalid_argument("[conv2d] Cannot handle groups != 1 yet");
    throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
  }
  if (dilation.first != 1 || dilation.second != 1) {
    throw std::invalid_argument("[conv2d] Cannot handle dilation != 1 yet");

  int spatial_dims = in.ndim() - 2;

  if (spatial_dims < 1 || spatial_dims > 2) {
    throw std::invalid_argument(
        "[conv] Can only work with inputs that have 1 or 2 spatial dimensions."
        " The inputs must be in the format [N, ..., C_in]");
  }

  // Run checks
  run_conv_checks(in_, wt_, 2);

  auto in = in_;
  auto wt = wt_;
  run_conv_checks(in, wt, spatial_dims);

  // Type promotion
  auto out_type = promote_types(in.dtype(), wt.dtype());
  in = astype(in, out_type, s);
  wt = astype(wt, out_type, s);

  std::vector<int> strides_vec = {stride.first, stride.second};
  std::vector<int> padding_vec = {padding.first, padding.second};
  std::vector<int> dilation_vec = {dilation.first, dilation.second};
  if (stride.size() <= 1) {
    int stride_int = stride.size() ? stride[0] : 1;
    stride = std::vector<int>(spatial_dims, stride_int);
  }

  if (padding_lo.size() <= 1) {
    int padding_int = padding_lo.size() ? padding_lo[0] : 0;
    padding_lo = std::vector<int>(spatial_dims, padding_int);
  }

  if (padding_hi.size() <= 1) {
    int padding_int = padding_hi.size() ? padding_hi[0] : 0;
    padding_hi = std::vector<int>(spatial_dims, padding_int);
  }

  if (kernel_dilation.size() <= 1) {
    int kernel_dilation_int = kernel_dilation.size() ? kernel_dilation[0] : 1;
    kernel_dilation = std::vector<int>(spatial_dims, kernel_dilation_int);
  }

  if (input_dilation.size() <= 1) {
    int input_dilation_int = input_dilation.size() ? input_dilation[0] : 1;
    input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
  }

  // Get output shapes
  std::vector<int> out_shape = conv_out_shape(
      in.shape(), wt.shape(), strides_vec, padding_vec, dilation_vec);
      in.shape(),
      wt.shape(),
      stride,
      padding_lo,
      padding_hi,
      kernel_dilation,
      input_dilation);

  return array(
      out_shape,
      in.dtype(),
      std::make_unique<Convolution>(
          to_stream(s),
          padding_vec,
          strides_vec,
          dilation_vec,
          std::vector<int>(2, 1)),
          stride,
          padding_lo,
          kernel_dilation,
          input_dilation,
          groups,
          flip),
      {in, wt});
}
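A usage sketch for the new general convolution, assuming it is exposed to Python as `mx.conv_general` with keyword arguments mirroring the C++ signature:

```python
import mlx.core as mx

x = mx.random.normal((1, 16, 16, 3))  # NHWC input
w = mx.random.normal((8, 3, 3, 3))    # (C_out, kH, kW, C_in) weights
y = mx.conv_general(
    x,
    w,
    stride=[2, 2],
    padding=[1, 1],          # same low/high padding on each spatial axis
    kernel_dilation=[1, 1],
    input_dilation=[1, 1],
)
# Per spatial axis: out = (16 + 1 + 1 - 3) / 2 + 1 = 8
print(y.shape)  # (1, 8, 8, 8)
```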
@ -3388,6 +3489,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
  return a;
}

std::vector<array> atleast_1d(
    const std::vector<array>& arrays,
    StreamOrDevice s /* = {} */) {
  std::vector<array> out;
  out.reserve(arrays.size());
  for (const auto& a : arrays) {
    out.push_back(atleast_1d(a, s));
  }
  return out;
}

array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
  switch (a.ndim()) {
    case 0:
@ -3399,6 +3511,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
  }
}

std::vector<array> atleast_2d(
    const std::vector<array>& arrays,
    StreamOrDevice s /* = {} */) {
  std::vector<array> out;
  out.reserve(arrays.size());
  for (const auto& a : arrays) {
    out.push_back(atleast_2d(a, s));
  }
  return out;
}

array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
  switch (a.ndim()) {
    case 0:
@ -3411,4 +3534,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
    return a;
  }
}

std::vector<array> atleast_3d(
    const std::vector<array>& arrays,
    StreamOrDevice s /* = {} */) {
  std::vector<array> out;
  out.reserve(arrays.size());
  for (const auto& a : arrays) {
    out.push_back(atleast_3d(a, s));
  }
  return out;
}

} // namespace mlx::core
46 mlx/ops.h
@ -1026,6 +1026,43 @@ array cummin(

/** Convolution operations */

/** General convolution with a filter */
array conv_general(
    array input,
    array weight,
    std::vector<int> stride = {},
    std::vector<int> padding_lo = {},
    std::vector<int> padding_hi = {},
    std::vector<int> kernel_dilation = {},
    std::vector<int> input_dilation = {},
    int groups = 1,
    bool flip = false,
    StreamOrDevice s = {});

/** General convolution with a filter */
inline array conv_general(
    const array& input,
    const array& weight,
    std::vector<int> stride = {},
    std::vector<int> padding = {},
    std::vector<int> kernel_dilation = {},
    std::vector<int> input_dilation = {},
    int groups = 1,
    bool flip = false,
    StreamOrDevice s = {}) {
  return conv_general(
      /* const array& input = */ input,
      /* const array& weight = */ weight,
      /* std::vector<int> stride = */ stride,
      /* std::vector<int> padding_lo = */ padding,
      /* std::vector<int> padding_hi = */ padding,
      /* std::vector<int> kernel_dilation = */ kernel_dilation,
      /* std::vector<int> input_dilation = */ input_dilation,
      /* int groups = */ groups,
      /* bool flip = */ flip,
      /* StreamOrDevice s = */ s);
}

/** 1D convolution with a filter */
array conv1d(
    const array& input,
@ -1123,7 +1160,16 @@ std::vector<array> depends(

/** convert an array to an atleast ndim array */
array atleast_1d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_1d(
    const std::vector<array>& a,
    StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_2d(
    const std::vector<array>& a,
    StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_3d(
    const std::vector<array>& a,
    StreamOrDevice s = {});

} // namespace mlx::core
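For reference, the scalar forms behave like their NumPy namesakes, and the new vector overloads simply map over a list of arrays. A quick sketch:

```python
import mlx.core as mx

x = mx.array(3.0)  # zero-dimensional array
print(mx.atleast_1d(x).shape)  # (1,)
print(mx.atleast_2d(x).shape)  # (1, 1)
print(mx.atleast_3d(x).shape)  # (1, 1, 1)
```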
@ -48,6 +48,54 @@ std::tuple<array, array, int> vmap_binary_op(
  return {a, b, to_ax};
}

std::tuple<array, array, array, int> vmap_ternary_op(
    const std::vector<array>& inputs,
    const std::vector<int>& axes,
    const Stream& stream) {
  assert(inputs.size() == 3);
  assert(axes.size() == 3);

  auto a = inputs[0];
  auto b = inputs[1];
  auto c = inputs[2];
  int ndim = std::max(
      {a.ndim() + (axes[0] == -1),
       b.ndim() + (axes[1] == -1),
       c.ndim() + (axes[2] == -1)});

  auto expand_dims = [stream, ndim](auto in) {
    auto shape = in.shape();
    shape.insert(shape.begin(), ndim - shape.size(), 1);
    return reshape(in, shape, stream);
  };

  int to_ax = (ndim - a.ndim()) + axes[0];
  int from_ax1 = (ndim - b.ndim()) + axes[1];
  int from_ax2 = (ndim - c.ndim()) + axes[2];
  a = expand_dims(a);
  b = expand_dims(b);
  c = expand_dims(c);

  auto find_tdims = [](auto x, int to_ax, int from_ax) {
    std::vector<int> tdims(x.ndim());
    std::iota(tdims.begin(), tdims.end(), 0);
    tdims.erase(tdims.begin() + from_ax);
    tdims.insert(tdims.begin() + to_ax, from_ax);
    return tdims;
  };

  if (to_ax != from_ax1) {
    std::vector<int> tdims = find_tdims(b, to_ax, from_ax1);
    b = transpose(b, tdims, stream);
  }

  if (to_ax != from_ax2) {
    std::vector<int> tdims = find_tdims(c, to_ax, from_ax2);
    c = transpose(c, tdims, stream);
  }
  return {a, b, c, to_ax};
}

} // namespace

std::vector<array> Primitive::jvp(
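`vmap_ternary_op` aligns the mapped axes of all three inputs so that `Select::vmap` (further below) can call straight back into `where`. A sketch of the user-facing effect:

```python
import mlx.core as mx

cond = mx.array([[True, False], [False, True]])
x = mx.ones((2, 2))
y = mx.zeros((2, 2))
# Mapping over the leading axis routes through Select::vmap.
print(mx.vmap(mx.where)(cond, x, y))  # [[1, 0], [0, 1]]
```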
@ -631,21 +679,13 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
  return axis_ == c_other.axis_;
}

std::vector<array> Convolution::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  assert(primals.size() == 2);
  std::vector<array> grads;

  // Collect info
  auto& in = primals[0];
  auto& wt = primals[1];
  auto cotan = cotangents[0];

  int O = wt.shape(0);

array conv_weight_backward_patches(
    const array& in,
    const array& wt,
    const array& cotan,
    const std::vector<int>& kernel_strides,
    const std::vector<int>& padding,
    StreamOrDevice s) {
  // Resolve Padded input shapes and strides
  std::vector<int> padding_starts(in.ndim(), 0);
  std::vector<int> padding_ends = in.shape();
@ -653,9 +693,9 @@ std::vector<array> Convolution::vjp(

  // padded shape
  for (int i = 1; i < in.ndim() - 1; i++) {
    in_padded_shape[i] += 2 * padding_[i - 1];
    padding_ends[i] += padding_[i - 1];
    padding_starts[i] += padding_[i - 1];
    in_padded_shape[i] += 2 * padding[i - 1];
    padding_ends[i] += padding[i - 1];
    padding_starts[i] += padding[i - 1];
  }

  // padded strides (contiguous)
@ -664,6 +704,12 @@
    in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1];
  }

  // Pad input
  std::vector<int> padded_axes(in.ndim() - 2, 0);
  std::iota(padded_axes.begin(), padded_axes.end(), 1);
  auto in_padded =
      pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);

  // Resolve strided patches

  // patches are shaped as
@ -678,62 +724,108 @@
  std::vector<size_t> patches_strides(patches_shape.size(), 1);
  patches_strides[0] = in_padded_strides[0];
  for (int i = 1; i < n_spatial_dim + 1; i++) {
    patches_strides[i] = in_padded_strides[i] * kernel_strides_[i - 1];
    patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1];
  }
  for (int i = 1; i < in.ndim(); i++) {
    patches_strides[n_spatial_dim + i] = in_padded_strides[i];
  }

  // Reshape cotangents and weights for gemm
  cotan = reshape(cotangents[0], {-1, O}, stream());
  auto weight_reshaped = reshape(wt, {O, -1}, stream());
  // Make patches from in
  auto in_patches = as_strided(in_padded, patches_shape, patches_strides, 0, s);

  // Prepare for matmul
  int O = wt.shape(0);
  auto cotan_mat = reshape(cotan, {-1, O}, s);
  in_patches = reshape(in_patches, {cotan_mat.shape(0), -1}, s);

  auto grad = matmul(transpose(cotan_mat, {1, 0}, s), in_patches, s);
  grad = reshape(grad, wt.shape(), s);
  return grad;
}

std::vector<array> Convolution::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  assert(primals.size() == 2);
  std::vector<array> grads;

  // Collect info
  auto& in = primals[0];
  auto& wt = primals[1];
  auto& cotan = cotangents[0];

  for (int a : argnums) {
    // Grads for input
    if (a == 0) {
      // Gemm with cotangents to get patches
      auto grad_patches = matmul(cotan, weight_reshaped, stream());
      std::vector<int> padding_lo = padding_;
      std::vector<int> padding_hi = padding_;

      // Prepare base grad array to accumulate on
      int in_padded_size = in_padded_strides[0] * in_padded_shape[0];
      auto grad = zeros(
          {
              in_padded_size,
          },
          in.dtype(),
      for (int i = 0; i < padding_lo.size(); ++i) {
        int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
        padding_lo[i] = wt_size - padding_[i] - 1;

        int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
        int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
        padding_hi[i] = in_size - out_size + padding_[i];
      }

      auto wt_trans = swapaxes(wt, 0, -1, stream());

      auto grad = conv_general(
          /* const array& input = */ cotan,
          /* const array& weight = */ wt_trans,
          /* std::vector<int> stride = */ input_dilation_,
          /* std::vector<int> padding_lo = */ padding_lo,
          /* std::vector<int> padding_hi = */ padding_hi,
          /* std::vector<int> kernel_dilation = */ kernel_dilation_,
          /* std::vector<int> input_dilation = */ kernel_strides_,
          /* int groups = */ 1,
          /* bool flip = */ !flip_,
          stream());

      // Create index map
      int patches_size = grad_patches.size();
      auto idx = arange(in_padded_size, stream());
      idx = as_strided(idx, patches_shape, patches_strides, 0, stream());
      idx = reshape(idx, {patches_size}, stream());

      // Flatten patches and scatter
      auto flat_patches = reshape(grad_patches, {patches_size, 1}, stream());
      grad = scatter_add(grad, idx, flat_patches, 0, stream());

      // Reshape and slice away padding
      grad = reshape(grad, in_padded_shape, stream());
      grad = slice(grad, padding_starts, padding_ends, stream());

      grads.push_back(grad);
    }
    // Grads for weight
    else if (a == 1) {
      // Make patches from in
      std::vector<int> padded_axes(in.ndim() - 2, 0);
      std::iota(padded_axes.begin(), padded_axes.end(), 1);
      auto in_padded = pad(
          in, padded_axes, padding_, padding_, array(0, in.dtype()), stream());
      auto in_patches =
          as_strided(in_padded, patches_shape, patches_strides, 0, stream());
      in_patches = reshape(in_patches, {cotan.shape(0), -1}, stream());
      bool no_dilation = true;

      auto grad =
          matmul(transpose(cotan, {1, 0}, stream()), in_patches, stream());
      grad = reshape(grad, wt.shape(), stream());
      grads.push_back(grad);
      for (int i = 0; i < input_dilation_.size(); i++) {
        no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
      }

      if (no_dilation) {
        auto grad = conv_weight_backward_patches(
            in, wt, cotan, kernel_strides_, padding_, stream());
        grads.push_back(grad);
      } else {
        std::vector<int> padding_lo = padding_;
        std::vector<int> padding_hi = padding_;

        for (int i = 0; i < padding_hi.size(); ++i) {
          int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
          int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
          int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
          padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
        }

        auto in_trans = swapaxes(in, 0, -1, stream());
        auto cotan_trans = swapaxes(cotan, 0, -1, stream());
        auto grad_trans = conv_general(
            /* const array& input = */ in_trans,
            /* const array& weight = */ cotan_trans,
            /* std::vector<int> stride = */ kernel_dilation_,
            /* std::vector<int> padding_lo = */ padding_lo,
            /* std::vector<int> padding_hi = */ padding_hi,
            /* std::vector<int> kernel_dilation = */ kernel_strides_,
            /* std::vector<int> input_dilation = */ input_dilation_,
            /* int groups = */ 1,
            /* bool flip = */ flip_,
            stream());
        auto grad = swapaxes(grad_trans, 0, -1, stream());
        grads.push_back(grad);
      }
    }
  }

@ -745,7 +837,8 @@ bool Convolution::is_equivalent(const Primitive& other) const {
  return padding_ == c_other.padding_ &&
      kernel_strides_ == c_other.kernel_strides_ &&
      kernel_dilation_ == c_other.kernel_dilation_ &&
      input_dilation_ == c_other.input_dilation_;
      input_dilation_ == c_other.input_dilation_ &&
      groups_ == c_other.groups_ && flip_ == c_other.flip_;
}

std::vector<array> Copy::vjp(
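The rewritten vjp expresses the input gradient as another `conv_general` with swapped roles (and a flipped kernel), and keeps the patches/GEMM route for weight gradients only when there is no dilation. A quick end-to-end check through the Python API:

```python
import mlx.core as mx

def loss(w):
    x = mx.ones((1, 8, 8, 3))
    return mx.conv2d(x, w, stride=(1, 1), padding=(1, 1)).sum()

w = mx.random.normal((4, 3, 3, 3))
# No dilation here, so the weight gradient takes the
# conv_weight_backward_patches path above.
print(mx.grad(loss)(w).shape)  # (4, 3, 3, 3)
```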
@ -1775,6 +1868,76 @@ std::pair<std::vector<array>, std::vector<int>> Multiply::vmap(
  return {{multiply(a, b, stream())}, {to_ax}};
}

std::vector<array> Select::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  assert(primals.size() == 3);
  assert(tangents.size() == 3);

  auto jvp_fun = [&](int i) {
    int arg = argnums[i];

    if (arg == 0) {
      return zeros_like(primals[0], stream());
    } else if (arg == 1) {
      return multiply(
          astype(primals[0], tangents[1].dtype(), stream()),
          tangents[1],
          stream());
    } else {
      return multiply(
          astype(
              logical_not(primals[0], stream()), tangents[2].dtype(), stream()),
          tangents[2],
          stream());
    }
  };

  array jvp = jvp_fun(argnums[0]);
  for (int i = 1; i < argnums.size(); i++) {
    jvp = add(jvp, jvp_fun(argnums[i]));
  }
  return {jvp};
}

std::vector<array> Select::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  assert(primals.size() == 3);
  assert(cotangents.size() == 1);

  std::vector<array> vjps;
  for (auto arg : argnums) {
    if (arg == 0) {
      vjps.push_back(zeros_like(primals[0], stream()));
    } else if (arg == 1) {
      vjps.push_back(multiply(
          astype(primals[0], cotangents[0].dtype(), stream()),
          cotangents[0],
          stream()));
    } else if (arg == 2) {
      vjps.push_back(multiply(
          astype(
              logical_not(primals[0], stream()),
              cotangents[0].dtype(),
              stream()),
          cotangents[0],
          stream()));
    }
  }
  return vjps;
}

std::pair<std::vector<array>, std::vector<int>> Select::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto [a, b, c, to_ax] = vmap_ternary_op(inputs, axes, stream());
  return {{where(a, b, c, stream())}, {to_ax}};
}

std::vector<array> Negative::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
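The `Select` gradients simply mask the cotangent by the condition (or its negation); the condition itself gets a zero gradient. An illustration:

```python
import mlx.core as mx

def f(x, y):
    return mx.where(x > 0, x * y, y).sum()

x = mx.array([2.0, -1.0])
y = mx.array([3.0, 4.0])
# Where the condition holds, d/dy (x * y) = x; elsewhere d/dy y = 1.
print(mx.grad(f, argnums=1)(x, y))  # array([2, 1], dtype=float32)
```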
@ -1925,7 +2088,10 @@ std::vector<array> Power::vjp(
          primals[1],
          stream()));
    } else {
      vjps.push_back(multiply(log(primals[0], stream()), outputs[0], stream()));
      auto& exp = outputs[0];
      auto exp_vjp = multiply(log(primals[0], stream()), outputs[0], stream());
      // 0 * log 0 -> 0
      vjps.push_back(where(exp, exp_vjp, array(0.0f, exp.dtype()), stream()));
    }
    vjps.back() = multiply(cotangents[0], vjps.back(), stream());
  }
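The `where` guard maps the `0 * log(0)` corner of the exponent gradient to zero instead of `nan`:

```python
import mlx.core as mx

# d/db of a**b is a**b * log(a); at a = 0 this used to evaluate to nan.
g = mx.grad(lambda b: mx.array(0.0) ** b)(mx.array(2.0))
print(g)  # 0.0
```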
@ -544,15 +544,19 @@ class Convolution : public UnaryPrimitive {
 public:
  explicit Convolution(
      Stream stream,
      const std::vector<int>& padding,
      const std::vector<int>& kernel_strides,
      const std::vector<int>& padding,
      const std::vector<int>& kernel_dilation,
      const std::vector<int>& input_dilation)
      const std::vector<int>& input_dilation,
      const int groups = 1,
      const bool flip = false)
      : UnaryPrimitive(stream),
        padding_(padding),
        kernel_strides_(kernel_strides),
        kernel_dilation_(kernel_dilation),
        input_dilation_(input_dilation){};
        input_dilation_(input_dilation),
        groups_(groups),
        flip_(flip){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -571,6 +575,8 @@ class Convolution : public UnaryPrimitive {
  std::vector<int> kernel_strides_;
  std::vector<int> kernel_dilation_;
  std::vector<int> input_dilation_;
  int groups_;
  bool flip_;

  void eval(const std::vector<array>& inputs, array& out);
};
@ -719,6 +725,23 @@ class DivMod : public Primitive {
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};

class Select : public UnaryPrimitive {
 public:
  explicit Select(Stream stream) : UnaryPrimitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(Select)
  DEFINE_DEFAULT_IS_EQUIVALENT()
  DEFINE_INPUT_OUTPUT_SHAPE()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

class Remainder : public UnaryPrimitive {
 public:
  explicit Remainder(Stream stream) : UnaryPrimitive(stream){};
@ -37,6 +37,7 @@ from mlx.nn.layers.activations import (
    relu,
    relu6,
    selu,
    sigmoid,
    silu,
    softmax,
    softplus,
@ -67,3 +68,4 @@ from mlx.nn.layers.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample
@ -18,7 +18,7 @@ def _make_activation_module(f):

@partial(mx.compile, shapeless=True)
def sigmoid(x):
    r"""Applies the element-wise function:
    r"""Applies the sigmoid function.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

@ -142,11 +142,11 @@ def log_sigmoid(x):


@partial(mx.compile, shapeless=True)
def gelu(x):
def gelu(x) -> mx.array:
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \\textrm{GELU}(x) = x * \Phi(x)
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

@ -185,11 +185,15 @@ def gelu_fast_approx(x):

    .. math::

        x = x \sigma\left(1.773 x\right)
        x = x \sigma\left(1.702 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.

    References:
    - https://github.com/hendrycks/GELUs
    - https://arxiv.org/abs/1606.08415
    """
    return x * mx.sigmoid(1.773 * x)
    return x * mx.sigmoid(1.702 * x)


def glu(x: mx.array, axis: int = -1) -> mx.array:
@ -199,7 +203,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        textrm{GLU}(x) = a * \sigma(b)
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
@ -260,6 +264,7 @@ def prelu(x: mx.array, alpha: mx.array) -> mx.array:
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
    r"""Applies the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    Reference: https://arxiv.org/abs/1908.08681
@ -297,7 +302,7 @@ class GLU(Module):
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        textrm{GLU}(x) = a * \sigma(b)
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
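With the constant corrected to 1.702 above, the fast approximation stays close to the exact GELU. A quick numeric check (a sketch, assuming `gelu` and `gelu_fast_approx` are importable from `mlx.nn` as the updated import list suggests):

import mlx.core as mx
from mlx.nn import gelu, gelu_fast_approx

x = mx.linspace(-3, 3, 101)
# x * sigmoid(1.702 x) tracks exact GELU to within roughly 1e-2.
print(mx.max(mx.abs(gelu(x) - gelu_fast_approx(x))))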
@ -7,6 +7,42 @@ import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
    if is_leaf_fn(model, value_key, value):
        return map_fn(value)

    elif isinstance(value, Module):
        return {
            k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
            for k, v in value.items()
            if filter_fn(value, k, v)
        }

    elif isinstance(value, dict):
        nd = {}
        for k, v in value.items():
            tk = f"{value_key}.{k}"
            nd[k] = (
                _unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
                if filter_fn(model, tk, v)
                else {}
            )
        return nd

    elif isinstance(value, list):
        nl = []
        for i, vi in enumerate(value):
            tk = f"{value_key}.{i}"
            nl.append(
                _unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
                if filter_fn(model, tk, vi)
                else {}
            )
        return nl

    raise RuntimeError("Unexpected leaf found while traversing the module")


class Module(dict):
    """Base class for building neural networks with MLX.

@ -98,10 +134,13 @@ class Module(dict):
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
            super(Module, self).__getattribute__(key)

    def __setattr__(self, key: str, val: Any):
        self[key] = val
        if isinstance(val, (mx.array, dict, list, tuple)):
            self[key] = val
        else:
            super(Module, self).__setattr__(key, val)

    def load_weights(
        self,
@ -245,31 +284,11 @@ class Module(dict):
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
        return {
            k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
            for k, v in self.items()
            if filter_fn(self, k, v)
        }

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
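The refactor above moves the recursive traversal into a module-level `_unwrap` so that `filter_and_map` can recurse across submodules. A usage sketch (hypothetical output; `filter_fn` receives `(module, key, value)`):

import mlx.nn as nn

model = nn.Linear(4, 2)
# Map every leaf to its shape; keep all keys.
shapes = model.filter_and_map(
    lambda m, k, v: True,
    map_fn=lambda v: v.shape,
)
print(shapes)  # e.g. {'weight': [2, 4], 'bias': [2]}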
python/mlx/nn/layers/upsample.py (new file, 205 lines)
@ -0,0 +1,205 @@
# Copyright © 2023-2024 Apple Inc.

import operator
from functools import reduce
from itertools import product
from typing import Literal, Tuple, Union

import mlx.core as mx
from mlx.nn.layers.base import Module


def _scaled_indices(N, scale, align_corners, dim, ndims):
    M = int(scale * N)
    if align_corners:
        indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
    else:
        step = 1 / scale
        start = ((M - 1) * step - N + 1) / 2
        indices = mx.arange(M, dtype=mx.float32) * step - start
        indices = mx.clip(indices, 0, N - 1)
    shape = [1] * ndims
    shape[dim] = -1

    return indices.reshape(shape)


def _nearest_indices(N, scale, dim, ndims):
    return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)


def _linear_indices(N, scale, align_corners, dim, ndims):
    indices = _scaled_indices(N, scale, align_corners, dim, ndims)
    indices_l = mx.floor(indices)
    indices_r = mx.ceil(indices)
    weight = indices - indices_l
    weight = mx.expand_dims(weight, -1)

    return (
        (indices_l.astype(mx.int32), 1 - weight),
        (indices_r.astype(mx.int32), weight),
    )


def upsample_nearest(x: mx.array, scale_factor: Tuple):
    dims = x.ndim - 2
    if dims != len(scale_factor):
        raise ValueError("A scale needs to be provided for each spatial dimension")

    # Integer scale_factors means we can simply expand-broadcast and reshape
    if tuple(map(int, scale_factor)) == scale_factor:
        shape = list(x.shape)
        for d in range(dims):
            shape.insert(2 + 2 * d, 1)
        x = x.reshape(shape)
        for d in range(dims):
            shape[2 + 2 * d] = int(scale_factor[d])
        x = mx.broadcast_to(x, shape)
        for d in range(dims):
            shape[d + 1] *= shape[d + 2]
            shape.pop(d + 2)
        x = x.reshape(shape)
        return x

    else:
        B, *N, C = x.shape
        indices = [slice(None)]
        for i, (n, s) in enumerate(zip(N, scale_factor)):
            indices.append(_nearest_indices(n, s, i, dims))
        indices = tuple(indices)

        return x[indices]
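The integer-scale fast path above repeats each sample by inserting a broadcast axis next to every spatial axis and then folding it back in. A minimal sketch in one spatial dimension (hypothetical values, not part of the diff):

import mlx.core as mx

x = mx.arange(3.0).reshape(1, 3, 1)    # (B, N, C)
y = x.reshape(1, 3, 1, 1)              # insert a size-1 axis after the spatial axis
y = mx.broadcast_to(y, (1, 3, 2, 1))   # repeat each sample twice, for free
y = y.reshape(1, 6, 1)                 # fold the repeat axis into the spatial axis
# y is [0, 0, 1, 1, 2, 2] along the spatial dimension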

def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
    dims = x.ndim - 2
    if dims != len(scale_factor):
        raise ValueError("A scale needs to be provided for each spatial dimension")

    B, *N, C = x.shape

    # Compute the sampling grid
    indices = []
    for i, (n, s) in enumerate(zip(N, scale_factor)):
        indices.append(_linear_indices(n, s, align_corners, i, dims))

    # Sample and compute the weights
    samples = []
    weights = []
    for idx_weight in product(*indices):
        idx, weight = zip(*idx_weight)
        samples.append(x[(slice(None),) + idx])
        weights.append(reduce(operator.mul, weight))

    # Interpolate
    return sum(wi * xi for wi, xi in zip(weights, samples))


class Upsample(Module):
    r"""Upsample the input signal spatially.

    The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
    2``. The first is the batch dimension and the last is the feature
    dimension.

    For example, an audio signal would be 3D with 1 spatial dimension, an image
    4D with 2 and so on and so forth.

    There are two upsampling algorithms implemented: nearest neighbor upsampling
    and linear interpolation. Both can be applied to any number of spatial
    dimensions and the linear interpolation will be bilinear, trilinear etc.
    when applied to more than one spatial dimension.

    .. note::
       When using one of the linear interpolation modes the ``align_corners``
       argument changes how the corners are treated in the input image. If
       ``align_corners=True`` then the top and left edge of the input and
       output will match, as will the bottom right edge.

    Parameters:
        scale_factor (float or tuple): The multiplier for the spatial size.
            If a ``float`` is provided, it is the multiplier for all spatial dimensions.
            Otherwise, the number of scale factors provided must match the
            number of spatial dimensions.
        mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
            ``"linear"``. Default: ``"nearest"``.
        align_corners (bool, optional): Changes the way the corners are treated
            during ``"linear"`` upsampling. See the note above and the
            examples below for more details. Default: ``False``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
        >>> x
        array([[[[1],
                 [2]],
                [[3],
                 [4]]]], dtype=int32)
        >>> n = nn.Upsample(scale_factor=2, mode='nearest')
        >>> n(x).squeeze()
        array([[1, 1, 2, 2],
               [1, 1, 2, 2],
               [3, 3, 4, 4],
               [3, 3, 4, 4]], dtype=int32)
        >>> b = nn.Upsample(scale_factor=2, mode='linear')
        >>> b(x).squeeze()
        array([[1, 1.25, 1.75, 2],
               [1.5, 1.75, 2.25, 2.5],
               [2.5, 2.75, 3.25, 3.5],
               [3, 3.25, 3.75, 4]], dtype=float32)
        >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
        >>> b(x).squeeze()
        array([[1, 1.33333, 1.66667, 2],
               [1.66667, 2, 2.33333, 2.66667],
               [2.33333, 2.66667, 3, 3.33333],
               [3, 3.33333, 3.66667, 4]], dtype=float32)
    """

    def __init__(
        self,
        scale_factor: Union[float, Tuple],
        mode: Literal["nearest", "linear"] = "nearest",
        align_corners: bool = False,
    ):
        super().__init__()
        if mode not in ["nearest", "linear"]:
            raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
        if isinstance(scale_factor, (list, tuple)):
            self.scale_factor = tuple(map(float, scale_factor))
        else:
            self.scale_factor = float(scale_factor)
        self.mode = mode
        self.align_corners = align_corners

    def _extra_repr(self) -> str:
        return (
            f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
            f"align_corners={self.align_corners}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        dims = x.ndim - 2
        if dims <= 0:
            raise ValueError(
                f"[Upsample] The input should have at least 1 spatial "
                f"dimension which means it should be at least 3D but "
                f"{x.ndim}D was provided"
            )

        scale_factor = self.scale_factor
        if isinstance(scale_factor, tuple):
            if len(scale_factor) != dims:
                raise ValueError(
                    f"[Upsample] One scale per spatial dimension is required but "
                    f"scale_factor={scale_factor} and the number of spatial "
                    f"dimensions were {dims}"
                )
        else:
            scale_factor = (scale_factor,) * dims

        if self.mode == "nearest":
            return upsample_nearest(x, scale_factor)

        else:
            return upsample_linear(x, scale_factor, self.align_corners)
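For intuition, the `product(*indices)` loop above reduces in 2D to the familiar corner blend: four neighboring samples weighted by separable fractional offsets. A standalone sketch (hypothetical values, plain Python):

# Bilinear blend of four neighbors; wx, wy are fractional offsets in [0, 1].
def bilinear(tl, tr, bl, br, wx, wy):
    return (
        tl * (1 - wy) * (1 - wx)
        + tr * (1 - wy) * wx
        + bl * wy * (1 - wx)
        + br * wy * wx
    )

print(bilinear(1.0, 2.0, 3.0, 4.0, 0.5, 0.5))  # 2.5, the center of the 2x2 patch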
@ -1,11 +1,12 @@
# Copyright © 2023-2024 Apple Inc.

import math
from typing import Callable, List

import mlx.core as mx


def exponential_decay(init: float, decay_rate: float):
def exponential_decay(init: float, decay_rate: float) -> Callable:
    r"""Make an exponential decay scheduler.

    Args:
@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
    return schedule


def step_decay(init: float, decay_rate: float, step_size: int):
def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
    r"""Make a step decay scheduler.

    Args:
@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
    return schedule


def cosine_decay(init: float, decay_steps: int):
def cosine_decay(init: float, decay_steps: int) -> Callable:
    r"""Make a cosine decay scheduler.

    Args:
@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
        return init * decay

    return scheduler


def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
    r"""Join multiple schedules to create a new schedule.

    Args:
        schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
            receives a step count indicating the number of steps since
            the :math:`i`-th boundary.
        boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
            that indicates when to transition between schedules.

    Example:
        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
        >>> cosine = optim.cosine_decay(1e-1, 200)
        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.0, dtype=float32)
        >>> for _ in range(12): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.0999938, dtype=float32)
    """
    if len(schedules) == 0:
        raise ValueError("Must provide at least 1 schedule to join.")

    if len(schedules) != len(boundaries) + 1:
        raise ValueError(
            f"Received {len(boundaries)} boundaries but "
            f"expected {len(schedules) - 1}."
        )

    def schedule(step):
        output = schedules[0](step)
        for boundary, schedule in zip(boundaries, schedules[1:]):
            output = mx.where(step < boundary, output, schedule(step - boundary))
        return output

    return schedule


def linear_schedule(init: float, end: float, steps: int) -> Callable:
    r"""Make a linear scheduler.

    Args:
        init (float): Initial value.
        end (float): Final value.
        steps (int): Number of steps to apply the schedule over. The value is
            ``end`` for any steps beyond ``steps``.

    Example:

        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
        >>> optimizer = optim.Adam(learning_rate=warmup)
        >>> optimizer.learning_rate
        array(0.0, dtype=float32)
        >>> for _ in range(101): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
    """
    if steps < 1:
        raise ValueError(f"steps must be greater than 0, but got {steps}.")

    def step_fn(step):
        step = mx.minimum(step, steps)
        return step * ((end - init) / steps) + init

    return step_fn
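Because `join_schedules` selects the active segment with `mx.where`, the joined schedule stays traceable. A small two-phase sketch (hypothetical schedule, assuming `mlx.optimizers` is imported as `optim`):

import mlx.core as mx
import mlx.optimizers as optim

# Constant 0.5 for 5 steps, then a 10-step linear decay toward 0.
s = optim.join_schedules(
    [lambda step: mx.array(0.5), optim.linear_schedule(0.5, 0.0, 10)],
    [5],
)
print([s(mx.array(i)).item() for i in (0, 4, 5, 10)])  # [0.5, 0.5, 0.5, 0.25]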
@ -14,6 +14,7 @@ pybind11_add_module(
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <cstdint>
#include <cstring>
@ -7,6 +7,7 @@
#include <pybind11/numpy.h>

#include "python/src/indexing.h"
#include "python/src/pybind11_numpy_fp16.h"
#include "python/src/utils.h"

#include "mlx/ops.h"
@ -350,55 +351,53 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
    shape.push_back(np_array.shape(i));
  }

  // Get dtype
  auto type = np_array.dtype();

  // Copy data and make array
  if (type.is(py::dtype::of<int>())) {
  if (py::isinstance<py::array_t<int32_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int32_t>(
        np_array, shape, dtype.value_or(int32));
  } else if (type.is(py::dtype::of<uint32_t>())) {
  } else if (py::isinstance<py::array_t<uint32_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint32_t>(
        np_array, shape, dtype.value_or(uint32));
  } else if (type.is(py::dtype::of<bool>())) {
  } else if (py::isinstance<py::array_t<bool>>(np_array)) {
    return np_array_to_mlx_contiguous<bool>(
        np_array, shape, dtype.value_or(bool_));
  } else if (type.is(py::dtype::of<double>())) {
  } else if (py::isinstance<py::array_t<double>>(np_array)) {
    return np_array_to_mlx_contiguous<double>(
        np_array, shape, dtype.value_or(float32));
  } else if (type.is(py::dtype::of<float>())) {
  } else if (py::isinstance<py::array_t<float>>(np_array)) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float32));
  } else if (type.is(py::dtype("float16"))) {
  } else if (py::isinstance<py::array_t<float16_t>>(np_array)) {
    return np_array_to_mlx_contiguous<float>(
        np_array, shape, dtype.value_or(float16));
  } else if (type.is(py::dtype::of<uint8_t>())) {
  } else if (py::isinstance<py::array_t<uint8_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint8_t>(
        np_array, shape, dtype.value_or(uint8));
  } else if (type.is(py::dtype::of<uint16_t>())) {
  } else if (py::isinstance<py::array_t<uint16_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint16_t>(
        np_array, shape, dtype.value_or(uint16));
  } else if (type.is(py::dtype::of<uint64_t>())) {
  } else if (py::isinstance<py::array_t<uint64_t>>(np_array)) {
    return np_array_to_mlx_contiguous<uint64_t>(
        np_array, shape, dtype.value_or(uint64));
  } else if (type.is(py::dtype::of<int8_t>())) {
  } else if (py::isinstance<py::array_t<int8_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int8_t>(
        np_array, shape, dtype.value_or(int8));
  } else if (type.is(py::dtype::of<int16_t>())) {
  } else if (py::isinstance<py::array_t<int16_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int16_t>(
        np_array, shape, dtype.value_or(int16));
  } else if (type.is(py::dtype::of<int64_t>())) {
  } else if (py::isinstance<py::array_t<int64_t>>(np_array)) {
    return np_array_to_mlx_contiguous<int64_t>(
        np_array, shape, dtype.value_or(int64));
  } else if (type.is(py::dtype::of<std::complex<float>>())) {
  } else if (py::isinstance<py::array_t<std::complex<float>>>(np_array)) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  } else if (type.is(py::dtype::of<std::complex<double>>())) {
  } else if (py::isinstance<py::array_t<std::complex<double>>>(np_array)) {
    return np_array_to_mlx_contiguous<std::complex<float>>(
        np_array, shape, dtype.value_or(complex64));
  } else {
    std::ostringstream msg;
    msg << "Cannot convert numpy array of type " << type << " to mlx array.";
    msg << "Cannot convert numpy array of type " << np_array.dtype()
        << " to mlx array.";
    throw std::invalid_argument(msg.str());
  }
}
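With the per-dtype dispatch above, each numpy dtype now maps to the matching mlx dtype, including float16 through the new caster header. A quick check (a sketch, not part of the diff):

import numpy as np
import mlx.core as mx

for np_dtype, mx_dtype in [(np.float16, mx.float16), (np.uint8, mx.uint8)]:
    a = mx.array(np.ones(3, dtype=np_dtype))
    assert a.dtype == mx_dtype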
@ -5,18 +5,88 @@
#include "mlx/backend/metal/metal.h"

namespace py = pybind11;
using namespace py::literals;

using namespace mlx::core;

void init_metal(py::module_& m) {
  py::module_ metal = m.def_submodule("metal", "mlx.metal");
  metal.def("is_available", &metal::is_available);
  metal.def(
      "cache_enabled",
      &metal::cache_enabled,
      "check if metal buffer cache is enabled, default is true");
      "is_available",
      &metal::is_available,
      R"pbdoc(
        Check if the Metal back-end is available.
      )pbdoc");
  metal.def(
      "set_cache_enabled",
      &metal::set_cache_enabled,
      "enable or disable metal buffer cache");
      "get_active_memory",
      &metal::get_active_memory,
      R"pbdoc(
        Get the actively used memory in bytes.

        Note, this will not always match memory use reported by the system because
        it does not include cached memory buffers.
      )pbdoc");
  metal.def(
      "get_peak_memory",
      &metal::get_peak_memory,
      R"pbdoc(
        Get the peak amount of used memory in bytes.

        The maximum memory used is recorded from the beginning of the program
        execution.
      )pbdoc");
  metal.def(
      "get_cache_memory",
      &metal::get_cache_memory,
      R"pbdoc(
        Get the cache size in bytes.

        The cache includes memory not currently used that has not been returned
        to the system allocator.
      )pbdoc");
  metal.def(
      "set_memory_limit",
      &metal::set_memory_limit,
      "limit"_a,
      py::kw_only(),
      "relaxed"_a = true,
      R"pbdoc(
        Set the memory limit.

        Memory allocations will wait on scheduled tasks to complete if the limit
        is exceeded. If there are no more scheduled tasks an error will be raised
        if ``relaxed`` is ``False``. Otherwise memory will be allocated
        (including the potential for swap) if ``relaxed`` is ``True``.

        The memory limit defaults to 1.5 times the maximum recommended working set
        size reported by the device.

        Args:
          limit (int): Memory limit in bytes.
          relaxed (bool, optional): If ``False`` an error is raised if the limit
            is exceeded. Default: ``True``

        Returns:
          int: The previous memory limit in bytes.
      )pbdoc");
  metal.def(
      "set_cache_limit",
      &metal::set_cache_limit,
      "limit"_a,
      R"pbdoc(
        Set the free cache limit.

        If using more than the given limit, free memory will be reclaimed
        from the cache on the next allocation. To disable the cache, set
        the limit to ``0``.

        The cache limit defaults to the memory limit. See
        :func:`set_memory_limit` for more details.

        Args:
          limit (int): The cache limit in bytes.

        Returns:
          int: The previous cache limit in bytes.
      )pbdoc");
}
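The memory introspection and limit APIs bound above can be exercised directly from Python. A usage sketch (values are illustrative):

import mlx.core as mx

if mx.metal.is_available():
    # Cap the free cache at 256 MiB; the setter returns the previous limit.
    previous = mx.metal.set_cache_limit(256 * 1024 * 1024)
    print(mx.metal.get_active_memory(),
          mx.metal.get_peak_memory(),
          mx.metal.get_cache_memory())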
@ -3081,7 +3081,7 @@ void init_ops(py::module_& m) {
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: Union[int, Tuple[int, int]] = 1, *, stream: Union[None, Stream, Device] = None) -> array
        conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array

        2D convolution over an input with several channels

@ -3105,6 +3105,114 @@ void init_ops(py::module_& m) {
          array: The convolved array.
      )pbdoc");
  m.def(
      "conv_general",
      [](const array& input,
         const array& weight,
         const std::variant<int, std::vector<int>>& stride,
         const std::variant<
             int,
             std::vector<int>,
             std::pair<std::vector<int>, std::vector<int>>>& padding,
         const std::variant<int, std::vector<int>>& kernel_dilation,
         const std::variant<int, std::vector<int>>& input_dilation,
         int groups,
         bool flip,
         StreamOrDevice s) {
        std::vector<int> stride_vec;
        std::vector<int> padding_lo_vec;
        std::vector<int> padding_hi_vec;
        std::vector<int> kernel_dilation_vec;
        std::vector<int> input_dilation_vec;

        if (auto pv = std::get_if<int>(&stride); pv) {
          stride_vec.push_back(*pv);
        } else {
          stride_vec = std::get<std::vector<int>>(stride);
        }

        if (auto pv = std::get_if<int>(&padding); pv) {
          padding_lo_vec.push_back(*pv);
          padding_hi_vec.push_back(*pv);
        } else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) {
          padding_lo_vec = *pv;
          padding_hi_vec = *pv;
        } else {
          auto [pl, ph] =
              std::get<std::pair<std::vector<int>, std::vector<int>>>(padding);
          padding_lo_vec = pl;
          padding_hi_vec = ph;
        }

        if (auto pv = std::get_if<int>(&kernel_dilation); pv) {
          kernel_dilation_vec.push_back(*pv);
        } else {
          kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation);
        }

        if (auto pv = std::get_if<int>(&input_dilation); pv) {
          input_dilation_vec.push_back(*pv);
        } else {
          input_dilation_vec = std::get<std::vector<int>>(input_dilation);
        }

        return conv_general(
            /* const array& input = */ input,
            /* const array& weight = */ weight,
            /* std::vector<int> stride = */ stride_vec,
            /* std::vector<int> padding_lo = */ padding_lo_vec,
            /* std::vector<int> padding_hi = */ padding_hi_vec,
            /* std::vector<int> kernel_dilation = */ kernel_dilation_vec,
            /* std::vector<int> input_dilation = */ input_dilation_vec,
            /* int groups = */ groups,
            /* bool flip = */ flip,
            s);
      },
      "input"_a,
      "weight"_a,
      py::pos_only(),
      "stride"_a = 1,
      "padding"_a = 0,
      "kernel_dilation"_a = 1,
      "input_dilation"_a = 1,
      "groups"_a = 1,
      "flip"_a = false,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        conv_general(input: array, weight: array, /, stride: Union[int, List[int]] = 1, padding: Union[int, List[int], Tuple[List[int], List[int]]] = 0, kernel_dilation: Union[int, List[int]] = 1, input_dilation: Union[int, List[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array

        General convolution over an input with several channels

        .. note::

           * Only 1d and 2d convolutions are supported at the moment
           * Only the default ``groups=1`` is currently supported.

        Args:
            input (array): Input array of shape ``(N, ..., C_in)``
            weight (array): Weight array of shape ``(C_out, ..., C_in)``
            stride (int or list(int), optional): :obj:`list` with kernel strides.
                All spatial dimensions get the same stride if
                only one number is specified. Default: ``1``.
            padding (int, list(int), or tuple(list(int), list(int)), optional):
                :obj:`list` with input padding. All spatial dimensions get the same
                padding if only one number is specified. Default: ``0``.
            kernel_dilation (int or list(int), optional): :obj:`list` with
                kernel dilation. All spatial dimensions get the same dilation
                if only one number is specified. Default: ``1``
            input_dilation (int or list(int), optional): :obj:`list` with
                input dilation. All spatial dimensions get the same dilation
                if only one number is specified. Default: ``1``
            groups (int, optional): Input feature groups. Default: ``1``.
            flip (bool, optional): Flip the order in which the spatial dimensions of
                the weights are processed. Performs the cross-correlation operator when
                ``flip`` is ``False`` and the convolution operator otherwise.
                Default: ``False``.

        Returns:
            array: The convolved array.
      )pbdoc");
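A short usage sketch of the binding above (hypothetical shapes; channels-last layout as documented):

import mlx.core as mx

x = mx.ones((1, 8, 8, 3))   # (N, H, W, C_in)
w = mx.ones((4, 3, 3, 3))   # (C_out, kH, kW, C_in)
y = mx.conv_general(x, w, stride=1, padding=1)
print(y.shape)              # (1, 8, 8, 4): 3x3 kernel, stride 1, padding 1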
  m.def(
      "save",
      &mlx_save_helper,
      "file"_a,
@ -3638,62 +3746,69 @@ void init_ops(py::module_& m) {
      )pbdoc");
  m.def(
      "atleast_1d",
      &atleast_1d,
      "a"_a,
      py::pos_only(),
      [](const py::args& arys, StreamOrDevice s) -> py::object {
        if (arys.size() == 1) {
          return py::cast(atleast_1d(arys[0].cast<array>(), s));
        }
        return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
      },
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
        atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

        Convert array to have at least one dimension.
        Convert all arrays to have at least one dimension.

        args:
            a (array): Input array
        Args:
            *arys: Input arrays.
            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

        Returns:
            array: An array with at least one dimension.
            array or list(array): An array or list of arrays with at least one dimension.
      )pbdoc");
  m.def(
      "atleast_2d",
      &atleast_2d,
      "a"_a,
      py::pos_only(),
      [](const py::args& arys, StreamOrDevice s) -> py::object {
        if (arys.size() == 1) {
          return py::cast(atleast_2d(arys[0].cast<array>(), s));
        }
        return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
      },
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
        atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

        Convert array to have at least two dimensions.
        Convert all arrays to have at least two dimensions.

        args:
            a (array): Input array
        Args:
            *arys: Input arrays.
            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

        Returns:
            array: An array with at least two dimensions.
            array or list(array): An array or list of arrays with at least two dimensions.
      )pbdoc");

  m.def(
      "atleast_3d",
      &atleast_3d,
      "a"_a,
      py::pos_only(),
      [](const py::args& arys, StreamOrDevice s) -> py::object {
        if (arys.size() == 1) {
          return py::cast(atleast_3d(arys[0].cast<array>(), s));
        }
        return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
      },
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
        atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
        atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

        Convert array to have at least three dimensions.
        Convert all arrays to have at least three dimensions.

        args:
            a (array): Input array
        Args:
            *arys: Input arrays.
            stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

        Returns:
            array: An array with at least three dimensions.
            array or list(array): An array or list of arrays with at least three dimensions.
      )pbdoc");
}
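The variadic rewrite above returns a single array for one input and a list for several. A usage sketch (shapes in the comments are illustrative):

import mlx.core as mx

a = mx.atleast_2d(mx.array(1.0))
print(a.shape)            # e.g. (1, 1)
a, b = mx.atleast_1d(mx.array(1.0), mx.array([2.0, 3.0]))
print(a.shape, b.shape)   # e.g. (1,) and (2,)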
python/src/pybind11_numpy_fp16.h (new file, 60 lines)
@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once

// A patch to get float16_t to work with pybind11 numpy arrays
// Derived from:
// https://github.com/pybind/pybind11/issues/1776#issuecomment-492230679

#include <pybind11/numpy.h>

namespace pybind11::detail {

template <typename T>
struct npy_scalar_caster {
  PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
  using Array = array_t<T>;

  bool load(handle src, bool convert) {
    // Taken from Eigen casters. Permits either scalar dtype or scalar array.
    handle type = dtype::of<T>().attr("type"); // Could make more efficient.
    if (!convert && !isinstance<Array>(src) && !isinstance(src, type))
      return false;
    Array tmp = Array::ensure(src);
    if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
      this->value = *tmp.data();
      return true;
    }
    return false;
  }

  static handle cast(T src, return_value_policy, handle) {
    Array tmp({1});
    tmp.mutable_at(0) = src;
    tmp.resize({});
    // You could also just return the array if you want a scalar array.
    object scalar = tmp[tuple()];
    return scalar.release();
  }
};

// Similar to enums in `pybind11/numpy.h`. Determined by doing:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;

// Kinda following:
// https://github.com/pybind/pybind11/blob/9bb3313162c0b856125e481ceece9d8faa567716/include/pybind11/numpy.h#L1000
template <>
struct npy_format_descriptor<float16_t> {
  static constexpr auto name = _("float16");
  static pybind11::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
    return reinterpret_borrow<pybind11::dtype>(ptr);
  }
};

template <>
struct type_caster<float16_t> : npy_scalar_caster<float16_t> {
  static constexpr auto name = _("float16");
};

} // namespace pybind11::detail
@ -11,6 +11,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "python/src/trees.h"

namespace py = pybind11;
using namespace py::literals;
@ -30,246 +31,6 @@ std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
  return vals;
}

void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
  std::function<void(py::handle)> recurse;
  recurse = [&](py::handle subtree) {
    if (py::isinstance<py::list>(subtree) ||
        py::isinstance<py::tuple>(subtree)) {
      for (auto item : subtree) {
        recurse(item);
      }
    } else if (py::isinstance<py::dict>(subtree)) {
      for (auto item : py::cast<py::dict>(subtree)) {
        recurse(item.second);
      }
    } else {
      visitor(subtree);
    }
  };

  recurse(tree);
}

template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
  int len = py::cast<T>(subtrees[0]).size();
  for (auto& subtree : subtrees) {
    if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
        py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
      throw std::invalid_argument(
          "[tree_map] Additional input tree is not a valid prefix of the first tree.");
    }
  }
}

py::object tree_map(
    const std::vector<py::object>& trees,
    std::function<py::object(const std::vector<py::object>&)> transform) {
  std::function<py::object(const std::vector<py::object>&)> recurse;

  recurse = [&](const std::vector<py::object>& subtrees) {
    if (py::isinstance<py::list>(subtrees[0])) {
      py::list l;
      std::vector<py::object> items(subtrees.size());
      validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
      for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::list>(subtrees[j])) {
            items[j] = py::cast<py::list>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l.append(recurse(items));
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtrees[0])) {
      // Check the rest of the subtrees
      std::vector<py::object> items(subtrees.size());
      int len = py::cast<py::tuple>(subtrees[0]).size();
      py::tuple l(len);
      validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
      for (int i = 0; i < len; ++i) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::tuple>(subtrees[j])) {
            items[j] = py::cast<py::tuple>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l[i] = recurse(items);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::dict>(subtrees[0])) {
      std::vector<py::object> items(subtrees.size());
      validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
      py::dict d;
      for (auto item : py::cast<py::dict>(subtrees[0])) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::dict>(subtrees[j])) {
            auto subdict = py::cast<py::dict>(subtrees[j]);
            if (!subdict.contains(item.first)) {
              throw std::invalid_argument(
                  "[tree_map] Tree is not a valid prefix tree of the first tree.");
            }
            items[j] = subdict[item.first];
          } else {
            items[j] = subtrees[j];
          }
        }
        d[item.first] = recurse(items);
      }
      return py::cast<py::object>(d);
    } else {
      return transform(subtrees);
    }
  };
  return recurse(trees);
}

py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform) {
  return tree_map({tree}, [&](std::vector<py::object> inputs) {
    return transform(inputs[0]);
  });
}

void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor) {
  std::function<py::object(py::handle)> recurse;
  recurse = [&](py::handle subtree) {
    if (py::isinstance<py::list>(subtree)) {
      auto l = py::cast<py::list>(subtree);
      for (int i = 0; i < l.size(); ++i) {
        l[i] = recurse(l[i]);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtree)) {
      for (auto item : subtree) {
        recurse(item);
      }
      return py::cast<py::object>(subtree);
    } else if (py::isinstance<py::dict>(subtree)) {
      auto d = py::cast<py::dict>(subtree);
      for (auto item : d) {
        d[item.first] = recurse(item.second);
      }
      return py::cast<py::object>(d);
    } else if (py::isinstance<array>(subtree)) {
      return visitor(subtree);
    } else {
      return py::cast<py::object>(subtree);
    }
  };
  recurse(tree);
}

// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
  size_t index = 0;
  tree_visit_update(
      tree, [&](py::handle node) { return py::cast(values[index++]); });
}

// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
    py::object& tree,
    const std::vector<array>& src,
    const std::vector<array>& dst) {
  std::unordered_map<uintptr_t, array> src_to_dst;
  for (int i = 0; i < src.size(); ++i) {
    src_to_dst.insert({src[i].id(), dst[i]});
  }
  tree_visit_update(tree, [&](py::handle node) {
    auto arr = py::cast<array>(node);
    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
      return py::cast(it->second);
    }
    return py::cast(arr);
  });
}

std::vector<array> tree_flatten(py::object tree, bool strict = true) {
  std::vector<array> flat_tree;

  tree_visit(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      flat_tree.push_back(py::cast<array>(obj));
    } else if (strict) {
      throw std::invalid_argument(
          "[tree_flatten] The argument should contain only arrays");
    }
  });

  return flat_tree;
}

py::object tree_unflatten(
    py::object tree,
    const std::vector<array>& values,
    int index = 0) {
  return tree_map(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}

py::object structure_sentinel() {
  static py::object sentinel;

  if (sentinel.ptr() == nullptr) {
    sentinel = py::capsule(&sentinel);
    // probably not needed but this should make certain that we won't ever
    // delete the sentinel
    sentinel.inc_ref();
  }

  return sentinel;
}

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
    py::object tree,
    bool strict = true) {
  auto sentinel = structure_sentinel();
  std::vector<array> flat_tree;
  auto structure = tree_map(
      tree,
      [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
        if (py::isinstance<array>(obj)) {
          flat_tree.push_back(py::cast<array>(obj));
          return sentinel;
        } else if (!strict) {
          return py::cast<py::object>(obj);
        } else {
          throw std::invalid_argument(
              "[tree_flatten] The argument should contain only arrays");
        }
      });

  return {flat_tree, structure};
}

py::object tree_unflatten_from_structure(
    py::object structure,
    const std::vector<array>& values,
    int index = 0) {
  auto sentinel = structure_sentinel();
  return tree_map(structure, [&](py::handle obj) {
    if (obj.is(sentinel)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}

auto validate_argnums_argnames(
    const std::optional<IntOrVec>& argnums,
    const StrOrVec& argnames) {
@ -582,9 +343,69 @@ struct PyCompiledFun {
  };

  py::object operator()(const py::args& args, const py::kwargs& kwargs) {
    auto inputs = tree_flatten(args, false);
    // Flat array inputs
    std::vector<array> inputs;

    auto compile_fun = [this, &args, &kwargs, num_args = inputs.size()](
    // Compilation constants which includes the tree structure of the arguments
    std::vector<uint64_t> constants;

    // Reserve some large primes to signify the presence of an array, a list or
    // a dict in order to encode the structure of the pytree. We choose primes
    // to reduce slightly the chances of these numbers occurring by a
    // multiplication as values in the constants list.
    constexpr uint64_t array_identifier = 18446744073709551557UL;
    constexpr uint64_t list_identifier = 18446744073709551533UL;
    constexpr uint64_t dict_identifier = 18446744073709551521UL;

    // Flatten the tree with hashed constants and structure
    std::function<void(py::handle)> recurse;
    recurse = [&](py::handle obj) {
      if (py::isinstance<py::list>(obj)) {
        auto l = py::cast<py::list>(obj);
        constants.push_back(list_identifier);
        for (int i = 0; i < l.size(); ++i) {
          recurse(l[i]);
        }
      } else if (py::isinstance<py::tuple>(obj)) {
        auto l = py::cast<py::tuple>(obj);
        constants.push_back(list_identifier);
        for (auto item : obj) {
          recurse(item);
        }
      } else if (py::isinstance<py::dict>(obj)) {
        auto d = py::cast<py::dict>(obj);
        constants.push_back(dict_identifier);
        for (auto item : d) {
          auto r = py::hash(item.first);
          constants.push_back(*reinterpret_cast<uint64_t*>(&r));
          recurse(item.second);
        }
      } else if (py::isinstance<array>(obj)) {
        inputs.push_back(py::cast<array>(obj));
        constants.push_back(array_identifier);
      } else if (py::isinstance<py::str>(obj)) {
        auto r = py::hash(obj);
        constants.push_back(*reinterpret_cast<uint64_t*>(&r));
      } else if (py::isinstance<py::int_>(obj)) {
        auto r = obj.cast<int64_t>();
        constants.push_back(*reinterpret_cast<uint64_t*>(&r));
      } else if (py::isinstance<py::float_>(obj)) {
        auto r = obj.cast<double>();
        constants.push_back(*reinterpret_cast<uint64_t*>(&r));
      } else {
        std::ostringstream msg;
        msg << "[compile] Function arguments must be trees of arrays "
            << "or constants (floats, ints, or strings), but received "
            << "type " << obj.get_type() << ".";
        throw std::invalid_argument(msg.str());
      }
    };

    recurse(args);
    int num_args = inputs.size();
    recurse(kwargs);

    auto compile_fun = [this, &args, &kwargs, num_args](
                           const std::vector<array>& a) {
      // Put tracers into captured inputs
      std::vector<array> flat_in_captures;
@ -619,14 +440,6 @@ struct PyCompiledFun {
      return outputs;
    };

    {
      auto flat_kwargs = tree_flatten(kwargs, false);
      inputs.insert(
          inputs.end(),
          std::make_move_iterator(flat_kwargs.begin()),
          std::make_move_iterator(flat_kwargs.end()));
    }

    if (!py::isinstance<py::none>(captured_inputs)) {
      auto flat_in_captures = tree_flatten(captured_inputs, false);
      inputs.insert(
@ -635,36 +448,6 @@ struct PyCompiledFun {
          std::make_move_iterator(flat_in_captures.end()));
    }

    // Collect the compilation constants
    std::vector<uint64_t> constants;
    auto value_hash = [](py::handle o) -> std::optional<uint64_t> {
      // Consider expanding tuples to their contents including start and end
      // ids
      if (py::isinstance<py::tuple>(o) || py::isinstance<py::str>(o)) {
        auto r = py::hash(o);
        return *reinterpret_cast<uint64_t*>(&r);
      } else if (py::isinstance<py::int_>(o)) {
        auto r = o.cast<int64_t>();
        return *reinterpret_cast<uint64_t*>(&r);
      } else if (py::isinstance<py::float_>(o)) {
        auto r = o.cast<double>();
        return *reinterpret_cast<uint64_t*>(&r);
      } else {
        return std::nullopt;
      }
    };
    for (int i = 0; i < args.size(); i++) {
      if (auto h = value_hash(args[i]); h.has_value()) {
        constants.push_back(*h);
      }
    }
    for (auto& pair : kwargs) {
      if (auto h = value_hash(pair.second); h.has_value()) {
        constants.push_back(*value_hash(pair.first));
        constants.push_back(*h);
      }
    }

    // Compile and call
    auto outputs =
        detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
@ -1017,7 +800,38 @@ void init_transforms(py::module_& m) {
          const py::object& inputs,
          const py::object& outputs,
          bool shapeless) {
        return py::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless});
        py::options options;
        options.disable_function_signatures();

        std::ostringstream doc;
        auto name = fun.attr("__name__").cast<std::string>();
        doc << name;

        // Try to get the signature
        auto inspect = py::module::import("inspect");
        if (!inspect.attr("isbuiltin")(fun).cast<bool>()) {
          doc << inspect.attr("signature")(fun)
                     .attr("__str__")()
                     .cast<std::string>();
        }

        // Try to get the doc string
        if (auto d = fun.attr("__doc__"); py::isinstance<py::str>(d)) {
          doc << "\n\n";
          auto dstr = d.cast<std::string>();
          // Add spaces to match first line indentation with remainder of
          // docstring
          for (int i = dstr.size() - 1; i >= 0 && dstr[i] == ' '; i--) {
            doc << ' ';
          }
          doc << dstr;
        }
        auto doc_str = doc.str();
        return py::cpp_function(
            PyCompiledFun{fun, inputs, outputs, shapeless},
            py::name(name.c_str()),
            py::doc(doc_str.c_str()));
      },
      "fun"_a,
      "inputs"_a = std::nullopt,
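The prime identifiers and hashed scalars in `PyCompiledFun::operator()` above make the pytree structure and non-array arguments part of the compilation key. A small sketch of the effect from Python (assuming string constants are supported as the hashing above suggests):

import mlx.core as mx

@mx.compile
def step(x, mode):
    # `mode` is hashed into the compilation constants, so each distinct
    # value traces and compiles its own graph.
    return x + 1 if mode == "inc" else x - 1

print(step(mx.array(0.0), "inc"))  # array(1, dtype=float32)
print(step(mx.array(0.0), "dec"))  # array(-1, dtype=float32)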
python/src/trees.cpp (new file, 243 lines)
@ -0,0 +1,243 @@
// Copyright © 2023-2024 Apple Inc.

#include "python/src/trees.h"

void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
  std::function<void(py::handle)> recurse;
  recurse = [&](py::handle subtree) {
    if (py::isinstance<py::list>(subtree) ||
        py::isinstance<py::tuple>(subtree)) {
      for (auto item : subtree) {
        recurse(item);
      }
    } else if (py::isinstance<py::dict>(subtree)) {
      for (auto item : py::cast<py::dict>(subtree)) {
        recurse(item.second);
      }
    } else {
      visitor(subtree);
    }
  };

  recurse(tree);
}

template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
  int len = py::cast<T>(subtrees[0]).size();
  for (auto& subtree : subtrees) {
    if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
        py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
      throw std::invalid_argument(
          "[tree_map] Additional input tree is not a valid prefix of the first tree.");
    }
  }
}

py::object tree_map(
    const std::vector<py::object>& trees,
    std::function<py::object(const std::vector<py::object>&)> transform) {
  std::function<py::object(const std::vector<py::object>&)> recurse;

  recurse = [&](const std::vector<py::object>& subtrees) {
    if (py::isinstance<py::list>(subtrees[0])) {
      py::list l;
      std::vector<py::object> items(subtrees.size());
      validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
      for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::list>(subtrees[j])) {
            items[j] = py::cast<py::list>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l.append(recurse(items));
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtrees[0])) {
      // Check the rest of the subtrees
      std::vector<py::object> items(subtrees.size());
      int len = py::cast<py::tuple>(subtrees[0]).size();
      py::tuple l(len);
      validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
      for (int i = 0; i < len; ++i) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::tuple>(subtrees[j])) {
            items[j] = py::cast<py::tuple>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l[i] = recurse(items);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::dict>(subtrees[0])) {
      std::vector<py::object> items(subtrees.size());
      validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
      py::dict d;
      for (auto item : py::cast<py::dict>(subtrees[0])) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::dict>(subtrees[j])) {
            auto subdict = py::cast<py::dict>(subtrees[j]);
            if (!subdict.contains(item.first)) {
              throw std::invalid_argument(
                  "[tree_map] Tree is not a valid prefix tree of the first tree.");
            }
            items[j] = subdict[item.first];
          } else {
            items[j] = subtrees[j];
          }
        }
        d[item.first] = recurse(items);
      }
      return py::cast<py::object>(d);
    } else {
      return transform(subtrees);
    }
  };
  return recurse(trees);
}

py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform) {
  return tree_map({tree}, [&](std::vector<py::object> inputs) {
    return transform(inputs[0]);
  });
}

void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor) {
  std::function<py::object(py::handle)> recurse;
  recurse = [&](py::handle subtree) {
    if (py::isinstance<py::list>(subtree)) {
      auto l = py::cast<py::list>(subtree);
      for (int i = 0; i < l.size(); ++i) {
        l[i] = recurse(l[i]);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtree)) {
      for (auto item : subtree) {
        recurse(item);
      }
      return py::cast<py::object>(subtree);
    } else if (py::isinstance<py::dict>(subtree)) {
      auto d = py::cast<py::dict>(subtree);
      for (auto item : d) {
        d[item.first] = recurse(item.second);
      }
      return py::cast<py::object>(d);
    } else if (py::isinstance<array>(subtree)) {
      return visitor(subtree);
    } else {
      return py::cast<py::object>(subtree);
    }
  };
  recurse(tree);
}

// Fill a pytree (recursive dict or list of dict or list)
// in place with the given arrays
// Non dict or list nodes are ignored
void tree_fill(py::object& tree, const std::vector<array>& values) {
  size_t index = 0;
  tree_visit_update(
      tree, [&](py::handle node) { return py::cast(values[index++]); });
}

// Replace all the arrays from the src values with the dst values in the tree
void tree_replace(
    py::object& tree,
    const std::vector<array>& src,
    const std::vector<array>& dst) {
  std::unordered_map<uintptr_t, array> src_to_dst;
  for (int i = 0; i < src.size(); ++i) {
    src_to_dst.insert({src[i].id(), dst[i]});
  }
  tree_visit_update(tree, [&](py::handle node) {
    auto arr = py::cast<array>(node);
    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
      return py::cast(it->second);
    }
    return py::cast(arr);
  });
}

std::vector<array> tree_flatten(py::object tree, bool strict /* = true */) {
  std::vector<array> flat_tree;

  tree_visit(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      flat_tree.push_back(py::cast<array>(obj));
    } else if (strict) {
      throw std::invalid_argument(
          "[tree_flatten] The argument should contain only arrays");
    }
  });

  return flat_tree;
}

py::object tree_unflatten(
    py::object tree,
    const std::vector<array>& values,
    int index /* = 0 */) {
  return tree_map(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}

py::object structure_sentinel() {
  static py::object sentinel;

  if (sentinel.ptr() == nullptr) {
    sentinel = py::capsule(&sentinel);
    // probably not needed but this should make certain that we won't ever
    // delete the sentinel
    sentinel.inc_ref();
  }

  return sentinel;
}

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
    py::object tree,
    bool strict /* = true */) {
  auto sentinel = structure_sentinel();
  std::vector<array> flat_tree;
  auto structure = tree_map(
      tree,
      [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
        if (py::isinstance<array>(obj)) {
          flat_tree.push_back(py::cast<array>(obj));
          return sentinel;
        } else if (!strict) {
          return py::cast<py::object>(obj);
        } else {
          throw std::invalid_argument(
              "[tree_flatten] The argument should contain only arrays");
        }
      });

  return {flat_tree, structure};
}

py::object tree_unflatten_from_structure(
    py::object structure,
    const std::vector<array>& values,
    int index /* = 0 */) {
  auto sentinel = structure_sentinel();
  return tree_map(structure, [&](py::handle obj) {
    if (obj.is(sentinel)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}
python/src/trees.h (new file, 60 lines)
@ -0,0 +1,60 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mlx/array.h"

namespace py = pybind11;
using namespace mlx::core;

void tree_visit(py::object tree, std::function<void(py::handle)> visitor);

py::object tree_map(
    const std::vector<py::object>& trees,
    std::function<py::object(const std::vector<py::object>&)> transform);

py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform);

void tree_visit_update(
    py::object tree,
    std::function<py::object(py::handle)> visitor);

/**
 * Fill a pytree (a recursive dict or list of dicts and lists) in place with
 * the given arrays.
 */
void tree_fill(py::object& tree, const std::vector<array>& values);

/**
 * Replace all the arrays from the src values with the dst values in the
 * tree.
 */
void tree_replace(
    py::object& tree,
    const std::vector<array>& src,
    const std::vector<array>& dst);

/**
 * Flatten a tree into a vector of arrays. If strict is true, then the
 * function will throw if the tree contains a leaf which is not an array.
 */
std::vector<array> tree_flatten(py::object tree, bool strict = true);

/**
 * Unflatten a tree from a vector of arrays.
 */
py::object tree_unflatten(
    py::object tree,
    const std::vector<array>& values,
    int index = 0);

std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
    py::object tree,
    bool strict = true);

py::object tree_unflatten_from_structure(
    py::object structure,
    const std::vector<array>& values,
    int index = 0);
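
Note: these bindings back the tree utilities exposed to Python. The public mlx.utils helpers flatten to dotted-key/value pairs rather than a bare array vector, so a round trip looks like the following sketch (a minimal illustration, not part of this diff):

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

# Flatten a nested dict/list of arrays into (dotted-key, value) pairs.
params = {"w": mx.zeros((2, 2)), "layers": [{"b": mx.ones((2,))}]}
flat = tree_flatten(params)  # e.g. [("w", array), ("layers.0.b", array)]

# Rebuild the original nesting from the flat list.
rebuilt = tree_unflatten(flat)
print(mx.array_equal(rebuilt["w"], params["w"]))  # True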
@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

import operator
import pickle
import unittest
import weakref
from itertools import permutations
@ -1440,6 +1441,15 @@ class TestArray(mlx_tests.MLXTestCase):
        b @= a
        self.assertTrue(mx.array_equal(a, b))

    def test_load_from_pickled_np(self):
        a = np.array([1, 2, 3], dtype=np.int32)
        b = pickle.loads(pickle.dumps(a))
        self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))

        a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
        b = pickle.loads(pickle.dumps(a))
        self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))


if __name__ == "__main__":
    unittest.main()
@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
        _, vjps = mx.vjp(func, (arr,), (cotan,))
        self.assertEqual(vjps[0].item(), 8.0)

    def test_power_grad(self):
        def fun(x, y):
            res = x - y
            return res**x

        grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
        self.assertEqual(grad.item(), 1.0)


if __name__ == "__main__":
    unittest.main()
@ -539,6 +539,72 @@ class TestCompile(mlx_tests.MLXTestCase):
        z = fun(mx.array(1), "two")
        self.assertEqual(z.item(), 3)

        # Test nested constant
        @partial(mx.compile)
        def fun(x, y):
            if y[0][0] == 1:
                return x + 1
            else:
                return x + 2

        z = fun(mx.array(1), [[1]])
        self.assertEqual(z.item(), 2)

        z = fun(mx.array(1), [[0]])
        self.assertEqual(z.item(), 3)

        @partial(mx.compile)
        def fun(x, a, b):
            for ai in a:
                for bi in b:
                    x = bi * x + ai
            return x

        z = fun(mx.array(1), [1, 1], [2])
        self.assertEqual(z.item(), 7)

        z = fun(mx.array(1), [1], [1, 2])
        self.assertEqual(z.item(), 5)

        counter = [0]

        @partial(mx.compile)
        def fun(x, y):
            counter[0] += 1
            return x + y

        z = fun(mx.array(1), 1)
        self.assertEqual(z.item(), 2)

        z = fun(1, mx.array(1))
        self.assertEqual(z.item(), 2)

        self.assertEqual(counter[0], 2)
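
Note: the tests above pin down a subtle contract: mx.compile treats non-array arguments (scalars, strings, nested lists) as compile-time constants, so a new constant value triggers a fresh trace while a repeated one reuses the cache. A minimal sketch of that behavior, using only the public mx.compile API:

from functools import partial

import mlx.core as mx

trace_count = [0]

@partial(mx.compile)
def scaled_add(x, scale):
    trace_count[0] += 1  # runs only while tracing, i.e. once per new constant
    return scale * x + 1

scaled_add(mx.array(1.0), 2.0)  # traces with scale == 2.0
scaled_add(mx.array(3.0), 2.0)  # cache hit: same constant
scaled_add(mx.array(3.0), 4.0)  # new constant: traces again
print(trace_count[0])  # 2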
    def test_compile_inf(self):

        @mx.compile
        def fun(x):
            return mx.isinf(x + 2)

        out = fun(mx.array([0.0]))
        self.assertEqual(out.item(), False)

    def test_unsupported_input_types(self):

        class MyClass:
            value = 1

        @mx.compile
        def fun(x, y):
            return x + y.value

        with self.assertRaises(ValueError):
            out = fun(mx.array(0.0), MyClass())

        with self.assertRaises(ValueError):
            out = fun(mx.array(0.0), y=MyClass())


if __name__ == "__main__":
    unittest.main()
@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.

import math
import unittest
@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):

            _, outs_mx = mx.vjp(
                f,
                [
                    in_mx,
                    wt_mx,
                ],
                [
                    ct_mx,
                ],
                [in_mx, wt_mx],
                [ct_mx],
            )
            pt_grad_in = F.grad.conv1d_input(
                in_pt.shape,
@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
            self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
                for idim, kdim, stride, padding, dilation in (
                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                ):
                    run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
                    run_conv2D_grad(
                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
                    )

    def __conv_general_test(
        self,
        in_shape,
        wt_shape,
        stride=1,
        padding=0,
        kernel_dilation=1,
        input_dilation=1,
        groups=1,
        flip=False,
        np_dtype=np.float32,
        atol=1e-5,
    ):

        with self.subTest(
            in_shape=in_shape,
            wt_shape=wt_shape,
            stride=stride,
            padding=padding,
            kernel_dilation=kernel_dilation,
            input_dilation=input_dilation,
            groups=groups,
            flip=flip,
            np_dtype=np_dtype,
        ):

            scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
            wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)

            in_mx, wt_mx = map(mx.array, (in_np, wt_np))

            in_pt, wt_pt = map(
                lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
                (in_np, wt_np),
            )

            out_mx = mx.conv_general(
                in_mx,
                wt_mx,
                stride=stride,
                padding=padding,
                kernel_dilation=kernel_dilation,
                input_dilation=input_dilation,
                groups=groups,
                flip=flip,
            )

            def conv_general_pt(
                inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
            ):

                C = inp.size()[1]
                ndim = inp.ndim - 2
                map_ints = lambda x: [x] * ndim if isinstance(x, int) else x

                stride, padding, kernel_dilation, input_dilation = map(
                    map_ints, (stride, padding, kernel_dilation, input_dilation)
                )

                torch_convt_list = (
                    F.conv_transpose1d,
                    F.conv_transpose2d,
                    F.conv_transpose3d,
                )
                torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)

                conv_f = torch_conv_list[ndim - 1]
                convt_f = torch_convt_list[ndim - 1]

                if flip:
                    wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))

                if not np.all(input_dilation == 1):
                    ones = torch.ones(
                        [C]
                        + [
                            1,
                        ]
                        * (ndim + 1)
                    ).to(inp.dtype)
                    inp = convt_f(inp, ones, stride=input_dilation, groups=C)

                return conv_f(
                    inp,
                    wt,
                    stride=stride,
                    padding=padding,
                    dilation=kernel_dilation,
                    groups=groups,
                )

            out_pt = conv_general_pt(
                in_pt,
                wt_pt,
                stride=stride,
                padding=padding,
                kernel_dilation=kernel_dilation,
                input_dilation=input_dilation,
                groups=groups,
                flip=flip,
            )

            out_pt = np.moveaxis(out_pt.numpy(), 1, -1)

            self.assertEqual(out_mx.shape, out_pt.shape)
            self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_general(self):
        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 5, 16)
        stride = (1, 1)
        padding = (2, 2)
        kernel_dilation = (2, 3)
        input_dilation = (1, 1)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 10, 16)
        stride = (2, 3)
        padding = (0, 0)
        kernel_dilation = (3, 2)
        input_dilation = (2, 4)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 10, 16)
        stride = (2, 2)
        padding = (3, 2)
        kernel_dilation = (3, 2)
        input_dilation = (2, 4)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 10, 16)
        stride = (2, 3)
        padding = (3, 2)
        kernel_dilation = (3, 2)
        input_dilation = (2, 5)
        flip = False

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )

        in_shape = (2, 32, 32, 16)
        wt_shape = (32, 5, 5, 16)
        stride = (2, 3)
        padding = (0, 0)
        kernel_dilation = (3, 1)
        input_dilation = (2, 5)
        flip = True

        self.__conv_general_test(
            in_shape,
            wt_shape,
            stride,
            padding,
            kernel_dilation,
            input_dilation,
            flip=flip,
        )


if __name__ == "__main__":
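
Note: for readers unfamiliar with the op under test, mx.conv_general subsumes the standard convolutions and adds kernel/input dilation, groups, and an optional kernel flip (flip=True gives a true convolution rather than a cross-correlation). A minimal NHWC sketch using the same keywords exercised above:

import mlx.core as mx

x = mx.random.normal((1, 8, 8, 3))  # NHWC input
w = mx.random.normal((4, 3, 3, 3))  # (O, kH, kW, C) weights
y = mx.conv_general(
    x,
    w,
    stride=(2, 2),
    padding=(1, 1),
    kernel_dilation=(1, 1),
    input_dilation=(1, 1),
    flip=False,
)
print(y.shape)  # (1, 4, 4, 4)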
@ -66,13 +66,15 @@ class TestLoad(mlx_tests.MLXTestCase):
    def test_save_and_load_safetensors(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        test_file = os.path.join(self.test_dir, "test.safetensors")
        with self.assertRaises(Exception):
            mx.save_safetensors("test", {"a": mx.ones((4, 4))}, {"testing": 0})
            mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})

        mx.save_safetensors(
            "test", {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
            test_file, {"test": mx.ones((2, 2))}, {"testing": "test", "format": "mlx"}
        )
        res = mx.load("test.safetensors", return_metadata=True)
        res = mx.load(test_file, return_metadata=True)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[1], {"testing": "test", "format": "mlx"})
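
Note: the metadata round trip being tested looks like this in user code (a minimal sketch; the file name is illustrative):

import mlx.core as mx

mx.save_safetensors(
    "weights.safetensors",
    {"w": mx.ones((2, 2))},
    {"format": "mlx", "note": "demo"},  # metadata values must be strings
)
arrays, metadata = mx.load("weights.safetensors", return_metadata=True)
print(metadata["note"])  # "demo"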
python/tests/test_metal.py (new file, 45 lines)
@ -0,0 +1,45 @@
# Copyright © 2023-2024 Apple Inc.

import unittest

import mlx.core as mx
import mlx_tests


class TestMetal(mlx_tests.MLXTestCase):

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_memory_info(self):
        old_limit = mx.metal.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.metal.get_cache_memory(), 0)
        self.assertEqual(mx.metal.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.metal.set_cache_limit(old_limit), old_limit)

        old_limit = mx.metal.set_memory_limit(10)
        self.assertTrue(mx.metal.set_memory_limit(old_limit), 10)
        self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory
        a = mx.zeros((4096,))
        mx.eval(a)
        active_mem = mx.metal.get_active_memory()
        self.assertTrue(active_mem >= 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b

        new_active_mem = mx.metal.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        peak_mem = mx.metal.get_peak_memory()
        self.assertTrue(peak_mem >= 4096 * 8)
        cache_mem = mx.metal.get_cache_memory()
        self.assertTrue(cache_mem >= 4096 * 4)


if __name__ == "__main__":
    unittest.main()
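
Note: each setter in the new memory API returns the previously set value, so a limit can be scoped and restored. A minimal sketch, assuming a Metal-capable machine:

import mlx.core as mx

if mx.metal.is_available():
    previous = mx.metal.set_cache_limit(0)  # disable buffer caching
    x = mx.zeros((1024,))
    mx.eval(x)
    print(mx.metal.get_active_memory(), mx.metal.get_peak_memory())
    mx.metal.set_cache_limit(previous)  # restore the old limit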
@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.

import os
import tempfile
@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx.utils import tree_flatten, tree_map


class TestBase(mlx_tests.MLXTestCase):
@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
        y_hat1 = nn.gelu_approx(x)
        y_hat2 = nn.gelu_fast_approx(x)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
        self.assertLess(mx.abs(y - y_hat2).max(), 0.025)

    def test_sin_pe(self):
        m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertTrue(y.shape, x.shape)
        self.assertTrue(y.dtype, mx.float16)

    def test_upsample(self):
        b, h, w, c = 1, 2, 2, 1
        scale_factor = 2
        upsample_nearest = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=True
        )
        upsample_bilinear = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=True
        )
        upsample_nearest = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=True
        )
        upsample_bilinear_no_align_corners = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=False
        )
        upsample_nearest_no_align_corners = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=False
        )
        # Test single feature map, align corners
        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
        expected_nearest = mx.array(
            [[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
        ).transpose((0, 2, 3, 1))
        expected_bilinear = mx.array(
            [
                [
                    [
                        [0, 0.333333, 0.666667, 1],
                        [0.666667, 1, 1.33333, 1.66667],
                        [1.33333, 1.66667, 2, 2.33333],
                        [2, 2.33333, 2.66667, 3],
                    ]
                ]
            ]
        ).transpose((0, 2, 3, 1))
        # Test single feature map, no align corners
        x = (
            mx.arange(1, b * h * w * c + 1)
            .reshape((b, c, h, w))
            .transpose((0, 2, 3, 1))
        )
        expected_bilinear_no_align_corners = mx.array(
            [
                [
                    [
                        [1.0000, 1.2500, 1.7500, 2.0000],
                        [1.5000, 1.7500, 2.2500, 2.5000],
                        [2.5000, 2.7500, 3.2500, 3.5000],
                        [3.0000, 3.2500, 3.7500, 4.0000],
                    ]
                ]
            ]
        ).transpose((0, 2, 3, 1))
        expected_nearest_no_align_corners = mx.array(
            [[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(
            np.allclose(
                upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
            )
        )
        self.assertTrue(
            np.allclose(
                upsample_bilinear_no_align_corners(x),
                expected_bilinear_no_align_corners,
            )
        )

        # Test a more complex batch
        b, h, w, c = 2, 3, 3, 2
        scale_factor = 2
        x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))

        upsample_nearest = nn.Upsample(
            scale_factor=scale_factor, mode="nearest", align_corners=True
        )
        upsample_bilinear = nn.Upsample(
            scale_factor=scale_factor, mode="linear", align_corners=True
        )

        expected_nearest = mx.array(
            [
                [
                    [
                        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
                        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
                        [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
                        [3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
                        [6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
                        [6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
                    ],
                    [
                        [9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
                        [9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
                        [12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
                        [12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
                        [15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
                        [15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
                    ],
                ],
                [
                    [
                        [18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
                        [18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
                        [21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
                        [21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
                        [24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
                        [24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
                    ],
                    [
                        [27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
                        [27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
                        [30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
                        [30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
                        [33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
                        [33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
                    ],
                ],
            ]
        ).transpose((0, 2, 3, 1))
        expected_bilinear = mx.array(
            [
                [
                    [
                        [0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
                        [1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
                        [2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
                        [3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
                        [4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
                        [6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
                    ],
                    [
                        [9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
                        [10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
                        [11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
                        [12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
                        [13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
                        [15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
                    ],
                ],
                [
                    [
                        [18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
                        [19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
                        [20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
                        [21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
                        [22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
                        [24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
                    ],
                    [
                        [27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
                        [28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
                        [29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
                        [30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
                        [31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
                        [33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
                    ],
                ],
            ]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

        # Test different height and width scale_factor
        b, h, w, c = 1, 2, 2, 2
        x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
        upsample_nearest = nn.Upsample(
            scale_factor=(2, 3), mode="nearest", align_corners=True
        )
        upsample_bilinear = nn.Upsample(
            scale_factor=(2, 3), mode="linear", align_corners=True
        )

        expected_nearest = mx.array(
            [
                [
                    [
                        [0, 0, 0, 1, 1, 1],
                        [0, 0, 0, 1, 1, 1],
                        [2, 2, 2, 3, 3, 3],
                        [2, 2, 2, 3, 3, 3],
                    ],
                    [
                        [4, 4, 4, 5, 5, 5],
                        [4, 4, 4, 5, 5, 5],
                        [6, 6, 6, 7, 7, 7],
                        [6, 6, 6, 7, 7, 7],
                    ],
                ]
            ]
        ).transpose((0, 2, 3, 1))
        expected_bilinear = mx.array(
            [
                [
                    [
                        [0, 0.2, 0.4, 0.6, 0.8, 1],
                        [0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
                        [1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
                        [2, 2.2, 2.4, 2.6, 2.8, 3],
                    ],
                    [
                        [4, 4.2, 4.4, 4.6, 4.8, 5],
                        [4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
                        [5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
                        [6, 6.2, 6.4, 6.6, 6.8, 7],
                    ],
                ]
            ]
        ).transpose((0, 2, 3, 1))
        self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
        self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))

        # Test repr
        self.assertEqual(
            str(nn.Upsample(scale_factor=2)),
            "Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
        )
        self.assertEqual(
            str(nn.Upsample(scale_factor=(2, 3))),
            "Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
        )
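
Note: in user code the layer is just a scale/mode pair applied to channels-last input, defaulting to nearest-neighbor without corner alignment (as the repr test above confirms). A minimal sketch:

import mlx.core as mx
import mlx.nn as nn

x = mx.arange(4).reshape((1, 2, 2, 1))  # (N, H, W, C)
up = nn.Upsample(scale_factor=2, mode="nearest")
print(up(x).shape)  # (1, 4, 4, 1)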
    def test_pooling(self):
        # Test 1d pooling
        x = mx.array(
@ -1047,6 +1047,11 @@ class TestOps(mlx_tests.MLXTestCase):
            a = mx.arange(0, float("inf"), float("inf"))
        with self.assertRaises(ValueError):
            a = mx.arange(float("inf"), 1, float("inf"))
        with self.assertRaises(ValueError):
            a = mx.arange(float("inf"), 1, 5)
        with self.assertRaises(ValueError):
            INT_MAX = 2147483647
            a = mx.arange(0, INT_MAX + 1, 1)

        a = mx.arange(5)
        expected = [0, 1, 2, 3, 4]
@ -1132,6 +1137,27 @@ class TestOps(mlx_tests.MLXTestCase):
        self.assertListEqual(a.tolist(), expected)
        self.assertEqual(a.dtype, mx.int32)

        a = mx.arange(0, 10, 100)
        expected = [0]
        self.assertListEqual(a.tolist(), expected)
        self.assertEqual(a.dtype, mx.int32)

        a = mx.arange(10, 0, 1)
        expected = []
        self.assertListEqual(a.tolist(), expected)

        a = mx.arange(10, 0, float("inf"))
        expected = []
        self.assertListEqual(a.tolist(), expected)

        a = mx.arange(0, 10, float("inf"))
        expected = [0]
        self.assertListEqual(a.tolist(), expected)

        a = mx.arange(0, -10, float("-inf"))
        expected = [0]
        self.assertListEqual(a.tolist(), expected)

    def test_unary_ops(self):
        def test_ops(npop, mlxop, x, y, atol):
            r_np = npop(x)
@ -1563,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
        shape = (3, 4, 5)
        for dtype in ("int32", "float32"):
            for axis in (None, 0, 1, 2):
                for kth in (-2, 2):
                for kth in (-2, 0, 2):
                    with self.subTest(dtype=dtype, axis=axis, kth=kth):
                        np.random.seed(0)
                        np_dtype = getattr(np, dtype)
@ -1579,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
                        self.assertTrue(np.array_equal(c_np, c_mx))
                        self.assertEqual(b_mx.dtype, a_mx.dtype)

                        top_k_mx = mx.topk(a_mx, kth, axis=axis)
                        self.assertTrue(np.all(c_np <= top_k_mx))
                        self.assertEqual(top_k_mx.dtype, a_mx.dtype)

                        if kth >= 0:
                            d_np = np.take(b_mx, np.arange(kth), axis=axis)
                            self.assertTrue(np.all(d_np <= c_mx))
                        top_k_mx = mx.topk(a_mx, kth, axis=axis)
                        top_k_np = np.take(
                            np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
                        )
                        self.assertTrue(np.all(top_k_np <= top_k_mx))
                        self.assertEqual(top_k_mx.dtype, a_mx.dtype)
                        N = a_mx.shape[axis] if axis is not None else a_mx.size
                        M = top_k_mx.shape[axis or 0]
                        self.assertEqual(M, (kth + N) % N)
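
Note: the shape assertion above encodes the rule M = (kth + N) % N, so a negative k yields N - |k| elements, and the returned values are the k largest along the axis in no guaranteed order. A minimal sketch:

import mlx.core as mx

a = mx.array([5, 1, 9, 3, 7])
print(mx.topk(a, 2))         # the two largest values, 7 and 9 (order unspecified)
print(mx.topk(a, -2).shape)  # (3,) per the (kth + N) % N shape rule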
    @unittest.skipIf(
        os.getenv("LOW_MEMORY", None) is not None,
@ -1906,12 +1935,16 @@ class TestOps(mlx_tests.MLXTestCase):
            [[[[1]], [[2]], [[3]]]],
        ]

        for array in arrays:
        mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
        atleast_arrays = mx.atleast_1d(*mx_arrays)

        for i, array in enumerate(arrays):
            mx_res = mx.atleast_1d(mx.array(array))
            np_res = np.atleast_1d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)
            self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))

    def test_atleast_2d(self):
        def compare_nested_lists(x, y):
@ -1936,12 +1969,16 @@ class TestOps(mlx_tests.MLXTestCase):
            [[[[1]], [[2]], [[3]]]],
        ]

        for array in arrays:
        mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
        atleast_arrays = mx.atleast_2d(*mx_arrays)

        for i, array in enumerate(arrays):
            mx_res = mx.atleast_2d(mx.array(array))
            np_res = np.atleast_2d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)
            self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))

    def test_atleast_3d(self):
        def compare_nested_lists(x, y):
@ -1966,12 +2003,16 @@ class TestOps(mlx_tests.MLXTestCase):
            [[[[1]], [[2]], [[3]]]],
        ]

        for array in arrays:
        mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
        atleast_arrays = mx.atleast_3d(*mx_arrays)

        for i, array in enumerate(arrays):
            mx_res = mx.atleast_3d(mx.array(array))
            np_res = np.atleast_3d(np.array(array))
            self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
            self.assertEqual(mx_res.shape, np_res.shape)
            self.assertEqual(mx_res.ndim, np_res.ndim)
            self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))


if __name__ == "__main__":
@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_schedule_joiner(self):
        boundaries = [2, 3, 4]
        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
        with self.assertRaises(ValueError):
            opt.schedulers.join_schedules(schedules, boundaries)
        boundaries = [2, 4]
        schedule = opt.schedulers.join_schedules(schedules, boundaries)
        self.assertEqual(schedule(0).item(), 3)
        self.assertEqual(schedule(1).item(), 3)
        self.assertEqual(schedule(2).item(), 4)
        self.assertEqual(schedule(3).item(), 4)
        self.assertEqual(schedule(5).item(), 5)
        self.assertEqual(schedule(7).item(), 5)

    def test_linear_warmup_with_cosine_decay(self):
        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
        cos_with_warmup = opt.schedulers.join_schedules(
            [warmup_schedule, cosine_schedule], [101]
        )
        self.assertEqual(cos_with_warmup(0), 0.0)
        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
        for _ in range(100):
            optimizer.update({}, {})
        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
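
Note: the join pattern above is the usual warmup recipe: a linear ramp handed off to a decay at a step boundary, then passed to the optimizer as its learning rate. A minimal sketch with the schedulers used in this diff (positional arguments as exercised by the tests):

import mlx.optimizers as optim

warmup = optim.schedulers.linear_schedule(0.0, 1e-3, 100)  # 0 -> 1e-3 over 100 steps
decay = optim.schedulers.cosine_decay(1e-3, 1000)          # cosine decay from 1e-3
lr = optim.schedulers.join_schedules([warmup, decay], [100])

optimizer = optim.Adam(learning_rate=lr)  # the schedule advances on each update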
    def test_compile_with_schedule(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.9)
        optimizer = opt.SGD(learning_rate=lr_schedule)
setup.py
@ -152,7 +152,7 @@ if __name__ == "__main__":

    setup(
        name="mlx",
        version=get_version("0.3.0"),
        version=get_version("0.5.0"),
        author="MLX Contributors",
        author_email="mlx@group.apple.com",
        description="A framework for machine learning on Apple silicon.",
@ -1075,6 +1075,37 @@ TEST_CASE("test jvp from vjp") {
    CHECK(compute_derivs(subtract));
    CHECK(compute_derivs(power));
  }

  // Conditional selection element-wise op
  {
    auto condition = random::randint(0, 2, {5, 10});
    auto x = random::uniform({5, 10});
    auto y = random::uniform({5, 10});
    eval(condition, x, y);

    auto compute_derivs = [&condition, &x, &y](auto fn) {
      auto fn_wrap = [&fn](std::vector<array> inputs) {
        return std::vector<array>{
            fn(inputs[0], inputs[1], inputs[2], default_device())};
      };

      // Compute vjp and add results
      auto vjps = vjp(fn_wrap, {condition, x, y}, {ones(x.shape())}).second;
      auto vjp_out = add(add(vjps[0], vjps[1]), vjps[2]);

      // Compute jvp
      array jvp_out =
          jvp(fn_wrap,
              {condition, x, y},
              {ones(condition.shape()), ones(y.shape()), ones(x.shape())})
              .second[0];

      array result = array_equal(vjp_out, jvp_out);
      return result.item<bool>();
    };

    CHECK(compute_derivs(where));
  }
}

TEST_CASE("test complex gradients") {
@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <array>
#include "doctest/doctest.h"
@ -473,41 +473,42 @@ TEST_CASE("test metal validation") {
  eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
}

TEST_CASE("test metal enable/disable cache") {
  // Test enable metal cache
TEST_CASE("test metal memory info") {
  // Test cache limits
  {
    metal::set_cache_enabled(true);
    CHECK(metal::cache_enabled());

    auto& a = metal::allocator();
    auto size = 100;
    auto buf = a.malloc(size, false);

    // Release a
    a.free(buf);

    // Check that the buffer length still equals the requested size
    CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
    auto old_limit = metal::set_cache_limit(0);
    {
      auto a = zeros({4096});
      eval(a);
    }
    CHECK_EQ(metal::get_cache_memory(), 0);
    CHECK_EQ(metal::set_cache_limit(old_limit), 0);
    CHECK_EQ(metal::set_cache_limit(old_limit), old_limit);
  }

  // Test disable metal cache
  // Test memory limits
  {
    metal::set_cache_enabled(false);
    CHECK(!metal::cache_enabled());
    auto old_limit = metal::set_memory_limit(10);
    CHECK_EQ(metal::set_memory_limit(old_limit), 10);
    CHECK_EQ(metal::set_memory_limit(old_limit), old_limit);
  }

    auto& a = metal::allocator();
    auto size = 100;
    auto buf = a.malloc(size, false);
    auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());
    unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
  // Query active and peak memory
  {
    auto a = zeros({4096});
    eval(a);
    auto active_mem = metal::get_active_memory();
    CHECK(active_mem >= 4096 * 4);
    {
      auto b = zeros({4096});
      eval(b);
    }
    auto new_active_mem = metal::get_active_memory();
    CHECK_EQ(new_active_mem, active_mem);
    auto peak_mem = metal::get_peak_memory();
    CHECK(peak_mem >= 4096 * 8);

    // Release a
    a.free(buf);

    // If released successfully, the first byte should differ from the
    // first byte read before the release
    unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);

    CHECK_NE(new_first_byte, first_byte);
    auto cache_mem = metal::get_cache_memory();
    CHECK(cache_mem >= 4096 * 4);
  }
}
@ -791,13 +791,13 @@ TEST_CASE("test reduction ops") {
  constexpr float inf = std::numeric_limits<float>::infinity();

  x = array({-inf, -inf});
  WARN_EQ(logsumexp(x).item<float>(), -inf);
  CHECK_EQ(logsumexp(x).item<float>(), -inf);

  x = array({0.0f, -inf});
  CHECK_EQ(logsumexp(x).item<float>(), 0.0f);

  x = array({0.0f, inf});
  WARN_EQ(logsumexp(x).item<float>(), inf);
  CHECK_EQ(logsumexp(x).item<float>(), inf);

  x = reshape(arange(6, float32), {2, 3});

@ -1858,6 +1858,14 @@ TEST_CASE("test scatter") {
  out = scatter(in, inds, updates, 0);
  CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());

  // Array scatters with col contiguous updates
  in = zeros({4, 4}, float32);
  inds = array({0, 1, 2, 3});
  updates = transpose(reshape(arange(16, float32), {4, 1, 4}));
  out = scatter(in, inds, updates, 0);
  CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))
            .item<bool>());

  // Irregular strided index and reduce collision test
  in = zeros({10}, float32);
  inds = broadcast_to(array(3), {10});
@ -1877,10 +1885,10 @@ TEST_CASE("test scatter") {

  // Irregularly strided updates test
  in = ones({3, 3});
  updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
  updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});
  inds = array({0});
  out = scatter(in, inds, updates, 0);
  CHECK(array_equal(out, zeros({3, 3})).item<bool>());
  CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());

  // Along different axis
  in = zeros({2, 3});
@ -2185,6 +2193,8 @@ TEST_CASE("test power") {
}

TEST_CASE("test where") {
  const float inf = std::numeric_limits<float>::infinity();

  array condition(true);
  array x(1.0f);
  array y(0.0f);
@ -2216,6 +2226,49 @@ TEST_CASE("test where") {
  out = where(condition, x, y);
  expected = array({1, 2, 2, 1}, {2, 2});
  CHECK(array_equal(where(condition, x, y), expected).item<bool>());

  condition = array(true);
  x = array({1, 2, 3});
  y = array({3, 6, 13});
  CHECK(array_equal(where(condition, x, y), array({1, 2, 3})).item<bool>());

  condition = array(false);
  x = array({1, 2, 3});
  y = array({3, 6, 13});
  CHECK(array_equal(where(condition, x, y), array({3, 6, 13})).item<bool>());

  condition = array({1, 1, 0});
  x = array({1, 2, 3});
  y = array({11, 12, 13});
  CHECK(array_equal(where(condition, x, y), array({1, 2, 13})).item<bool>());

  condition = array({true, false}, {2, 1, 1});
  x = array({1, 2, 3, 4}, {2, 1, 2});
  y = array({11, 22, 33, 44}, {2, 2, 1});
  expected = array({1, 2, 1, 2, 33, 33, 44, 44}, {2, 2, 2});
  CHECK(array_equal(where(condition, x, y), expected).item<bool>());

  condition = array({true, false, false});
  x = array({inf, 2.0, 3.0});
  y = array({10.0, 20.0, -inf});
  CHECK(array_equal(where(condition, x, y), array({inf, 20.0, -inf}))
            .item<bool>());

  // 4-dim optimized case.
  condition = array({false});
  x = array({1, 2}, {2, 1, 1, 1});
  y = array({3, 4}, {1, 1, 2, 1});
  CHECK(array_equal(where(condition, x, y), array({3, 4, 3, 4}, {2, 1, 2, 1}))
            .item<bool>());

  // 5-dim optimized case.
  condition = array({true, false}, {2, 1, 1, 1, 1});
  x = array({1, 2, 3, 4}, {2, 1, 1, 1, 2});
  y = array({11, 22}, {1, 1, 2, 1, 1});
  CHECK(array_equal(
            where(condition, x, y),
            array({1, 2, 1, 2, 11, 11, 22, 22}, {2, 1, 2, 1, 2}))
            .item<bool>());
}

TEST_CASE("test stack") {
@ -2734,6 +2787,19 @@ TEST_CASE("test atleast_1d") {
  CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_1d vector") {
  auto x = std::vector<array>{
      array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
  auto out = atleast_1d(x);
  CHECK_EQ(out.size(), 3);
  CHECK_EQ(out[0].ndim(), 1);
  CHECK_EQ(out[0].shape(), std::vector<int>{1});
  CHECK_EQ(out[1].ndim(), 1);
  CHECK_EQ(out[1].shape(), std::vector<int>{3});
  CHECK_EQ(out[2].ndim(), 2);
  CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_2d") {
  auto x = array(1);
  auto out = atleast_2d(x);
@ -2751,6 +2817,19 @@ TEST_CASE("test atleast_2d") {
  CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_2d vector") {
  auto x = std::vector<array>{
      array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
  auto out = atleast_2d(x);
  CHECK_EQ(out.size(), 3);
  CHECK_EQ(out[0].ndim(), 2);
  CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
  CHECK_EQ(out[1].ndim(), 2);
  CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
  CHECK_EQ(out[2].ndim(), 2);
  CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_3d") {
  auto x = array(1);
  auto out = atleast_3d(x);
@ -2766,4 +2845,36 @@ TEST_CASE("test atleast_3d") {
  out = atleast_3d(x);
  CHECK_EQ(out.ndim(), 3);
  CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
}
}

TEST_CASE("test atleast_3d vector") {
  auto x = std::vector<array>{
      array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
  auto out = atleast_3d(x);
  CHECK_EQ(out.size(), 3);
  CHECK_EQ(out[0].ndim(), 3);
  CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
  CHECK_EQ(out[1].ndim(), 3);
  CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
  CHECK_EQ(out[2].ndim(), 3);
  CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
}

TEST_CASE("test topk") {
  auto x = reshape(arange(10), {2, 5});

  {
    auto y = topk(x, 1, 1);
    CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
  }

  {
    auto y = topk(x, 2, 0);
    CHECK(array_equal(y, x).item<bool>());
  }

  {
    auto y = topk(x, 1, 0);
    CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
  }
}
@ -248,11 +248,9 @@ TEST_CASE("test random uniform") {
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), float32);

  if (is_available(float16)) {
    x = random::uniform({}, float16);
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), float16);
  }
  x = random::uniform({}, float16);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), float16);

  x = random::uniform({0});
  CHECK(array_equal(x, array({})).item<bool>());
@ -467,11 +465,9 @@ TEST_CASE("test random bernoulli") {
  CHECK_EQ(x.dtype(), bool_);

  // Bernoulli parameter can have floating point type
  if (is_available(float16)) {
    x = random::bernoulli(array(0.5, float16));
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), bool_);
  }
  x = random::bernoulli(array(0.5, float16));
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), bool_);

  CHECK_THROWS(random::bernoulli(array(1, int32)));

@ -513,11 +509,9 @@ TEST_CASE("Test truncated normal") {
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), float32);

  if (is_available(float16)) {
    x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
    CHECK_EQ(x.size(), 1);
    CHECK_EQ(x.dtype(), float16);
  }
  x = random::truncated_normal(array(-2.0), array(2.0), {}, float16);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.dtype(), float16);

  // Requested shape
  x = random::truncated_normal(array(-2.0), array(2.0), {3, 4});
@ -138,6 +138,70 @@ TEST_CASE("test simple vmap") {
    CHECK(array_equal(out, x + y).item<bool>());
  }

  // vmap where (ternary op)
  {
    auto fun = [](std::vector<array> inputs) {
      auto out = where(inputs[0], inputs[1], inputs[2]);
      return std::vector<array>{out};
    };

    auto vfun = vmap(fun);
    array cond({true, false}, {2, 1});
    array x({1.0, 2.0}, {2, 1});
    array y({2.0, 4.0}, {2, 1});
    auto out = vfun({cond, x, y})[0];
    CHECK(array_equal(out, array({1.0, 4.0}, {2, 1})).item<bool>());

    cond = array({true, true, false}, {1, 3});
    x = ones({2, 1, 3});
    y = zeros({3, 2});
    vfun = vmap(fun, {1, 2, 0});
    out = vfun({cond, x, y})[0];

    CHECK(
        array_equal(out, array({1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0}, {3, 2, 2}))
            .item<bool>());

    vfun = vmap(fun, {1, 2, 0}, {1});
    out = vfun({cond, x, y})[0];
    CHECK(
        array_equal(out, array({1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0}, {2, 3, 2}))
            .item<bool>());

    cond = array({true, false});
    x = array(2.);
    y = ones({3, 2});
    vfun = vmap(fun, {-1, -1, 0});
    out = vfun({cond, x, y})[0];
    CHECK(array_equal(out, array({2, 1, 2, 1, 2, 1}, {3, 2})).item<bool>());

    cond = array({true, false});
    x = ones({3, 2});
    y = array(2.);
    vfun = vmap(fun, {-1, 0, -1});
    out = vfun({cond, x, y})[0];
    CHECK(array_equal(out, array({1, 2, 1, 2, 1, 2}, {3, 2})).item<bool>());

    CHECK_THROWS_AS(vmap(fun, {-1, -1, -1}, {0}), std::invalid_argument);
    CHECK_THROWS_AS(vmap(fun, {-1, 0, -1}, {-1}), std::invalid_argument);
    CHECK_THROWS_AS(vmap(fun, {-1, -1, 0}, {-1}), std::invalid_argument);
    CHECK_THROWS_AS(vmap(fun, {0, -1, -1}, {-1}), std::invalid_argument);

    cond = array({true, false});
    x = array(1.);
    y = array(2.);
    vfun = vmap(fun, {-1, -1, -1}, {-1});
    out = vfun({cond, x, y})[0];
    CHECK(array_equal(out, array({1.0, 2.0})).item<bool>());

    cond = array({1, 1, 1, 0, 0, 0}, {3, 2, 1});
    x = ones({3, 2, 1});
    y = full({3, 2, 1}, 2);
    vfun = vmap(vmap(fun));
    out = vfun({cond, x, y})[0];
    CHECK(array_equal(out, array({1, 1, 1, 2, 2, 2}, {3, 2, 1})).item<bool>());
  }
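
Note: the same batching rules are exposed in Python through mx.vmap, where in_axes/out_axes play the role of the integer lists above (with -1 marking an unmapped input in the C++ API). A minimal sketch of the ternary case:

import mlx.core as mx

def select(c, x, y):
    return mx.where(c, x, y)

vselect = mx.vmap(select, in_axes=(0, 0, 0), out_axes=0)
c = mx.array([True, False])
x = mx.array([1.0, 2.0])
y = mx.array([10.0, 20.0])
print(vselect(c, x, y))  # array([1, 20], dtype=float32)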
  // vmap with capturing closure
  {
    auto x = add(add(ones({2}), zeros({2})), zeros({2}));