diff --git a/.circleci/config.yml b/.circleci/config.yml index 250f35faa..94e4e909f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -237,6 +237,14 @@ workflows: jobs: - mac_build_and_test - linux_build_and_test + + build_pypi_release: + when: + and: + - not: << pipeline.parameters.nightly_build >> + - not: << pipeline.parameters.weekly_build >> + - not: << pipeline.parameters.test_release >> + jobs: - build_release: filters: tags: diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index e15aafd5b..dca91bf90 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals: - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. -- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. +- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Luca Arnaboldi: Added `Ceil` and `Floor` ops. @@ -256,4 +256,4 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file +limitations under the License. diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b1a4cf52..c78c5d756 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) - set(MLX_VERSION 0.3.0) + set(MLX_VERSION 0.5.0) endif() # --------------------- Processor tests ------------------------- @@ -67,8 +67,6 @@ if (MLX_BUILD_METAL AND NOT METAL_LIB) set(MLX_BUILD_METAL OFF) elseif (MLX_BUILD_METAL) message(STATUS "Building METAL sources") - add_compile_definitions(_METAL_) - # Throw an error if xcrun not found execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" OUTPUT_VARIABLE MACOS_VERSION diff --git a/README.md b/README.md index 118cc828e..80114d23e 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,12 @@ brought to you by Apple machine learning research. Some key features of MLX include: - - **Familiar APIs**: MLX has a Python API that closely follows NumPy. - MLX also has a fully featured C++ API, which closely mirrors the Python API. - MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs - that closely follow PyTorch to simplify building more complex models. + - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX + also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and + [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror + the Python API. MLX has higher-level packages like `mlx.nn` and + `mlx.optimizers` with APIs that closely follow PyTorch to simplify building + more complex models. 
- **Composable function transformations**: MLX supports composable function transformations for automatic differentiation, automatic vectorization, diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp index 69cba09e9..4505282f1 100644 --- a/benchmarks/cpp/single_ops.cpp +++ b/benchmarks/cpp/single_ops.cpp @@ -73,6 +73,7 @@ void time_unary_ops() { void time_binary_ops() { int M = 1000, N = 100, K = 10; + auto condition = random::randint(0, 2, {M, N, K}); auto a = random::uniform({M, N, K}); auto b = random::uniform({M, N, K}); auto device = default_device(); @@ -84,7 +85,9 @@ void time_binary_ops() { TIME(divide, a, b, device); TIME(maximum, a, b, device); TIME(minimum, a, b, device); + TIME(where, condition, a, b, device); + condition = array({true}); b = random::uniform({1}); eval(b); TIMEM("scalar", add, a, b, device); @@ -93,7 +96,9 @@ void time_binary_ops() { TIMEM("scalar", multiply, a, b, device); TIMEM("vector-scalar", divide, a, b, device); TIMEM("scalar-vector", divide, b, a, device); + TIMEM("scalar-vector", where, condition, a, b, device); + condition = broadcast_to(array({true}), {1000, 100}); a = broadcast_to(random::uniform({1}), {1000, 100}); b = broadcast_to(random::uniform({1}), {1000, 100}); eval(a, b); @@ -101,6 +106,7 @@ void time_binary_ops() { TIMEM("scalar-scalar broadcast", subtract, a, b, device); TIMEM("scalar-scalar broadcast", multiply, a, b, device); TIMEM("scalar-scalar broadcast", divide, a, b, device); + TIMEM("scalar-scalar broadcast", where, condition, a, b, device); } void time_strided_ops() { diff --git a/benchmarks/python/conv_bench.py b/benchmarks/python/conv_bench.py new file mode 100644 index 000000000..f052487d9 --- /dev/null +++ b/benchmarks/python/conv_bench.py @@ -0,0 +1,129 @@ +import argparse +import math +import os +import subprocess +import time + +import mlx.core as mx +import numpy as np +import torch + +device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) +device_name = device_name.decode("utf-8").strip("\n") + +N_warmup = 10 +N_iter_bench = 100 +N_iter_func = 5 + + +def bench(f, a, b): + for i in range(N_warmup): + f(a, b) + torch.mps.synchronize() + + s = time.perf_counter_ns() + for i in range(N_iter_bench): + f(a, b) + e = time.perf_counter_ns() + return (e - s) * 1e-9 + + +def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)): + def mx_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = mx.conv2d(a, b, stride=strides, padding=padding) + ys.append(y) + mx.eval(ys) + return ys + + return mx_conv_2D + + +def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)): + @torch.no_grad() + def pt_conv_2D(a, b): + ys = [] + for i in range(N_iter_func): + y = torch.conv2d(a, b, stride=strides, padding=padding) + ys.append(y) + torch.mps.synchronize() + return ys + + return pt_conv_2D + + +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype): + + scale = 1.0 / math.sqrt(kH * kH * C) + a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype) + + a_mx = mx.array(a_np) + b_mx = mx.array(b_np) + + a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps") + b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps") + + torch.mps.synchronize() + + f_mx = make_mx_conv_2D(strides, padding) + f_pt = make_pt_conv_2D(strides, padding) + + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) + + out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding) + out_pt = 
torch.conv2d( + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)) + out_pt = out_pt.numpy(force=True) + + atol = 2e-5 if np_dtype == np.float32 else 1e-4 + + if not np.allclose(out_pt, out_mx, atol=atol): + print( + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + ) + + return time_mlx, time_torch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run conv benchmarks") + + dtypes = ("float32",) + shapes = ( + (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)), + (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)), + (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)), + (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)), + (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)), + (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)), + (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)), + (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)), + (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)), + (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)), + (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)), + (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)), + (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)), + (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)), + (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)), + ) + + for dtype in dtypes: + print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%") + for N, H, W, C, kH, kW, O, strides, padding in shapes: + np_dtype = getattr(np, dtype) + time_mlx, time_torch = bench_shape( + N, H, W, C, kH, kW, O, strides, padding, np_dtype + ) + diff = time_torch / time_mlx - 1.0 + + print( + f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. 
* diff:+5.2f}%" + ) + if time_mlx >= 2.0 * time_torch: + print("ATTENTION ^^^^^^^") diff --git a/benchmarks/python/scatter_bench.py b/benchmarks/python/scatter_bench.py index 2d63d8bf1..d2fd569ac 100644 --- a/benchmarks/python/scatter_bench.py +++ b/benchmarks/python/scatter_bench.py @@ -7,12 +7,14 @@ import torch from time_utils import measure_runtime -def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape): +def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes): def scatter(dst, x, idx): - dst[idx] = x + dst[*idx] = x mx.eval(dst) - idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape) + idx = [] + for idx_shape in idx_shapes: + idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape)) x = mx.random.normal(x_shape).astype(mx.float32) dst = mx.random.normal(dst_shape).astype(mx.float32) @@ -20,13 +22,15 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape): print(f"MLX: {runtime:.3f}ms") -def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device): +def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device): def gather(dst, x, idx, device): - dst[idx] = x + dst[*idx] = x if device == torch.device("mps"): torch.mps.synchronize() - idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device) + idx = [] + for idx_shape in idx_shapes: + idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)) x = torch.randn(x_shape, dtype=torch.float32).to(device) dst = torch.randn(dst_shape, dtype=torch.float32).to(device) @@ -45,9 +49,45 @@ if __name__ == "__main__": else: device = torch.device("mps") - dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)] - idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)] - x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)] + dst_shapes = [ + (10, 64), + (100_000, 64), + (1_000_000, 64), + (100_000,), + (2_000_00,), + (20_000_000,), + (10000, 64), + (100, 64), + (100, 10_000, 64), + (10, 100, 100, 21), + (1_000, 1_000, 10), + ] + idx_shapes = [ + [(1_000_000,)], + [(1_000_000,)], + [(100_000,)], + [(1_000_000,)], + [(20_000_000,)], + [(20_000_000,)], + [(1000000,)], + [(10000000,)], + [(1_000,)], + [(10_000,)], + [(1_000,), (1_000,)], + ] + x_shapes = [ + (1_000_000, 64), + (1_000_000, 64), + (100_000, 64), + (1_000_000,), + (20_000_000,), + (20_000_000,), + (1000000, 64), + (10000000, 64), + (1_000, 10_000, 64), + (10_000, 100, 100, 21), + (1_000, 10), + ] for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes): print("=" * 20) diff --git a/docs/src/_static/mlx_logo.png b/docs/src/_static/mlx_logo.png index 49400bd8d..be122bf7c 100644 Binary files a/docs/src/_static/mlx_logo.png and b/docs/src/_static/mlx_logo.png differ diff --git a/docs/src/_static/mlx_logo_dark.png b/docs/src/_static/mlx_logo_dark.png new file mode 100644 index 000000000..cda3c1f61 Binary files /dev/null and b/docs/src/_static/mlx_logo_dark.png differ diff --git a/docs/src/conf.py b/docs/src/conf.py index 0654cf53c..603bfa847 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -49,10 +49,12 @@ html_theme_options = { "repository_url": "https://github.com/ml-explore/mlx", "use_repository_button": True, "navigation_with_keys": False, + "logo": { + "image_light": "_static/mlx_logo.png", + "image_dark": "_static/mlx_logo_dark.png", + }, } -html_logo = "_static/mlx_logo.png" - # -- Options for HTMLHelp output --------------------------------------------- diff --git a/docs/src/index.rst b/docs/src/index.rst index 50dfe9083..e54a55b7a 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -64,6 +64,7 @@ are the CPU and 
GPU. python/transforms python/fft python/linalg + python/metal python/nn python/optimizers python/tree_utils diff --git a/docs/src/python/metal.rst b/docs/src/python/metal.rst new file mode 100644 index 000000000..c11deb4fa --- /dev/null +++ b/docs/src/python/metal.rst @@ -0,0 +1,14 @@ +Metal +===== + +.. currentmodule:: mlx.core.metal + +.. autosummary:: + :toctree: _autosummary + + is_available + get_active_memory + get_peak_memory + get_cache_memory + set_memory_limit + set_cache_limit diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst index fc99dcad1..db276afdf 100644 --- a/docs/src/python/nn/functions.rst +++ b/docs/src/python/nn/functions.rst @@ -12,13 +12,24 @@ simple functions. :toctree: _autosummary_functions :template: nn-module-template.rst + elu gelu gelu_approx gelu_fast_approx + glu + hardswish + leaky_relu + log_sigmoid + log_softmax mish prelu relu + relu6 selu - softshrink + sigmoid silu + softmax + softplus + softshrink step + tanh diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index 0f5fca9db..f6755e8fe 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -40,3 +40,4 @@ Layers Softshrink Step Transformer + Upsample \ No newline at end of file diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 7ec7defc9..6396bb3c6 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -35,6 +35,7 @@ Operations convolve conv1d conv2d + conv_general cos cosh dequantize @@ -56,6 +57,7 @@ Operations greater_equal identity inner + isclose isnan isposinf isneginf @@ -120,6 +122,8 @@ Operations tan tanh tensordot + tile + topk transpose tri tril diff --git a/docs/src/python/optimizers/schedulers.rst b/docs/src/python/optimizers/schedulers.rst index a83883ddb..50855e1e7 100644 --- a/docs/src/python/optimizers/schedulers.rst +++ b/docs/src/python/optimizers/schedulers.rst @@ -8,6 +8,8 @@ Schedulers .. 
autosummary:: :toctree: _autosummary - step_decay - exponential_decay cosine_decay + exponential_decay + join_schedules + linear_schedule + step_decay diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index e147b5888..1d4258f62 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -64,6 +64,7 @@ DEFAULT(Reshape) DEFAULT(Remainder) DEFAULT(Round) DEFAULT(Scatter) +DEFAULT(Select) DEFAULT(Sigmoid) DEFAULT(Sign) DEFAULT(Slice) diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 38a9819e5..22b6ea6f5 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -1,6 +1,9 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(COMPILER ${CMAKE_C_COMPILER}) set(CLANG TRUE) +else() + set(COMPILER ${CMAKE_CXX_COMPILER}) endif() add_custom_command( @@ -8,16 +11,16 @@ add_custom_command( COMMAND /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp - ${CMAKE_CXX_COMPILER} - ${CMAKE_SOURCE_DIR} + ${COMPILER} + ${PROJECT_SOURCE_DIR} ${CLANG} DEPENDS make_compiled_preamble.sh compiled_preamble.h - ${CMAKE_SOURCE_DIR}/mlx/types/half_types.h - ${CMAKE_SOURCE_DIR}/mlx/types/fp16.h - ${CMAKE_SOURCE_DIR}/mlx/types/bf16.h - ${CMAKE_SOURCE_DIR}/mlx/types/complex.h + ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h + ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h + ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h + ${PROJECT_SOURCE_DIR}/mlx/types/complex.h ops.h ) @@ -43,6 +46,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index fb397b669..673d9cd14 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -9,7 +9,7 @@ namespace mlx::core { namespace { -enum BinaryOpType { +enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, @@ -20,17 +20,17 @@ enum BinaryOpType { BinaryOpType get_binary_op_type(const array& a, const array& b) { BinaryOpType bopt; if (a.data_size() == 1 && b.data_size() == 1) { - bopt = ScalarScalar; + bopt = BinaryOpType::ScalarScalar; } else if (a.data_size() == 1 && b.flags().contiguous) { - bopt = ScalarVector; + bopt = BinaryOpType::ScalarVector; } else if (b.data_size() == 1 && a.flags().contiguous) { - bopt = VectorScalar; + bopt = BinaryOpType::VectorScalar; } else if ( a.flags().row_contiguous && b.flags().row_contiguous || a.flags().col_contiguous && b.flags().col_contiguous) { - bopt = VectorVector; + bopt = BinaryOpType::VectorVector; } else { - bopt = General; + bopt = BinaryOpType::General; } return bopt; } @@ -42,11 +42,11 @@ void set_binary_op_output_data( BinaryOpType bopt, bool donate_with_move = false) { switch (bopt) { - case ScalarScalar: + case BinaryOpType::ScalarScalar: out.set_data( allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); break; - case ScalarVector: + case BinaryOpType::ScalarVector: if (b.is_donatable() && b.itemsize() == out.itemsize()) { if (donate_with_move) { out.move_shared_buffer(b); @@ -61,7 +61,7 @@ void set_binary_op_output_data( b.flags()); } break; - case VectorScalar: + case BinaryOpType::VectorScalar: if (a.is_donatable() && a.itemsize() == out.itemsize()) { if (donate_with_move) { 
out.move_shared_buffer(a); @@ -76,7 +76,7 @@ void set_binary_op_output_data( a.flags()); } break; - case VectorVector: + case BinaryOpType::VectorVector: if (a.is_donatable() && a.itemsize() == out.itemsize()) { if (donate_with_move) { out.move_shared_buffer(a); @@ -97,7 +97,7 @@ void set_binary_op_output_data( a.flags()); } break; - case General: + case BinaryOpType::General: if (a.is_donatable() && a.flags().row_contiguous && a.itemsize() == out.itemsize() && a.size() == out.size()) { if (donate_with_move) { @@ -424,25 +424,25 @@ void binary_op( set_binary_op_output_data(a, b, out, bopt); // The full computation is scalar scalar so call the base op once - if (bopt == ScalarScalar) { + if (bopt == BinaryOpType::ScalarScalar) { *(out.data()) = op(*a.data(), *b.data()); return; } // The full computation is scalar vector so delegate to the op - if (bopt == ScalarVector) { + if (bopt == BinaryOpType::ScalarVector) { opsv(a.data(), b.data(), out.data(), b.data_size()); return; } // The full computation is vector scalar so delegate to the op - if (bopt == VectorScalar) { + if (bopt == BinaryOpType::VectorScalar) { opvs(a.data(), b.data(), out.data(), a.data_size()); return; } // The full computation is vector vector so delegate to the op - if (bopt == VectorVector) { + if (bopt == BinaryOpType::VectorVector) { opvv(a.data(), b.data(), out.data(), out.size()); return; } @@ -475,17 +475,17 @@ void binary_op( // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous int dim = ndim; if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { - bopt = VectorVector; + bopt = BinaryOpType::VectorVector; dim = d; // Case 2: LxM and Fx1 where L and F are broadcastable and M is row // contiguous } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { - bopt = VectorScalar; + bopt = BinaryOpType::VectorScalar; dim = d; // Case 3: Lx1 and FxM where L and F are broadcastable and M is row // contiguous } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { - bopt = ScalarVector; + bopt = BinaryOpType::ScalarVector; dim = d; } @@ -495,20 +495,20 @@ void binary_op( size_t stride; if (dim == 0 || strides[dim - 1] < 16) { stride = 1; - bopt = General; + bopt = BinaryOpType::General; dim = ndim; } else { stride = strides[dim - 1]; } switch (bopt) { - case VectorVector: + case BinaryOpType::VectorVector: binary_op_dispatch_dims(a, b, out, opvv, dim, stride); break; - case VectorScalar: + case BinaryOpType::VectorScalar: binary_op_dispatch_dims(a, b, out, opvs, dim, stride); break; - case ScalarVector: + case BinaryOpType::ScalarVector: binary_op_dispatch_dims(a, b, out, opsv, dim, stride); break; default: diff --git a/mlx/backend/common/binary_two.h b/mlx/backend/common/binary_two.h index 3468cb61e..3ce2f7110 100644 --- a/mlx/backend/common/binary_two.h +++ b/mlx/backend/common/binary_two.h @@ -260,14 +260,14 @@ void binary_op( set_binary_op_output_data(a, b, out_b, bopt); // The full computation is scalar scalar so call the base op once - if (bopt == ScalarScalar) { + if (bopt == BinaryOpType::ScalarScalar) { std::tie(*(out_a.data()), *(out_b.data())) = op(*a.data(), *b.data()); return; } // The full computation is scalar vector so delegate to the op - if (bopt == ScalarVector) { + if (bopt == BinaryOpType::ScalarVector) { opsv( a.data(), b.data(), @@ -278,7 +278,7 @@ void binary_op( } // The full computation is vector scalar so delegate to the op - if (bopt == VectorScalar) { + if (bopt == BinaryOpType::VectorScalar) { opvs( a.data(), b.data(), @@ -289,7 +289,7 @@ void binary_op( } // The 
full computation is vector vector so delegate to the op - if (bopt == VectorVector) { + if (bopt == BinaryOpType::VectorVector) { opvv( a.data(), b.data(), @@ -327,17 +327,17 @@ void binary_op( // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous int dim = ndim; if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { - bopt = VectorVector; + bopt = BinaryOpType::VectorVector; dim = d; // Case 2: LxM and Fx1 where L and F are broadcastable and M is row // contiguous } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { - bopt = VectorScalar; + bopt = BinaryOpType::VectorScalar; dim = d; // Case 3: Lx1 and FxM where L and F are broadcastable and M is row // contiguous } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { - bopt = ScalarVector; + bopt = BinaryOpType::ScalarVector; dim = d; } @@ -347,20 +347,20 @@ void binary_op( size_t stride; if (dim == 0 || strides[dim - 1] < 16) { stride = 1; - bopt = General; + bopt = BinaryOpType::General; dim = ndim; } else { stride = strides[dim - 1]; } switch (bopt) { - case VectorVector: + case BinaryOpType::VectorVector: binary_op_dispatch_dims(a, b, out_a, out_b, opvv, dim, stride); break; - case VectorScalar: + case BinaryOpType::VectorScalar: binary_op_dispatch_dims(a, b, out_a, out_b, opvs, dim, stride); break; - case ScalarVector: + case BinaryOpType::ScalarVector: binary_op_dispatch_dims(a, b, out_a, out_b, opsv, dim, stride); break; default: diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index 3f4f09d0c..5a8495040 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -1,6 +1,7 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include +#include #ifdef ACCELERATE_NEW_LAPACK #include @@ -27,14 +28,16 @@ void slow_conv_1D( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { const T* start_wt_ptr = wt.data(); const T* in_ptr = in.data(); T* out_ptr = out.data(); const int N = in.shape(0); // Batch size, should be the same as out.shape(0) - const int iH = in.shape(1); // Input spatial dim + const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int oH = out.shape(1); // Output spatial dim const int O = wt.shape(0); // Out channels const int C = wt.shape(2); // In channels @@ -61,12 +64,15 @@ void slow_conv_1D( for (int wh = 0; wh < wH; ++wh) { const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; - int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0]; + int wh_flip = flip ? 
(wH - wh - 1) : wh; + int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0]; - if (ih >= 0 && ih < iH) { + auto ih_div = std::div(ih, in_dilation[0]); + + if (ih >= 0 && ih < iH && ih_div.rem == 0) { for (int c = 0; c < C; ++c) { r += static_cast( - in_ptr[ih * in_stride_H + c * in_stride_C]) * + in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) * static_cast(wt_ptr[c * wt_stride_C]); } // c @@ -90,14 +96,16 @@ void slow_conv_2D( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { const T* st_wt_ptr = wt.data(); const T* st_in_ptr = in.data(); T* st_out_ptr = out.data(); const int N = in.shape(0); // Batch size, should be the same as out.shape(0) - const int iH = in.shape(1); // Input spatial dim - const int iW = in.shape(2); // Input spatial dim + const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim + const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim const int oH = out.shape(1); // Output spatial dim const int oW = out.shape(2); // Output spatial dim const int O = wt.shape(0); // Out channels @@ -120,6 +128,8 @@ void slow_conv_2D( const size_t out_stride_W = out.strides()[2]; const size_t out_stride_O = out.strides()[3]; + bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; + auto pt_conv_no_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { out_ptr += oh * out_stride_H + ow * out_stride_W; @@ -131,8 +141,10 @@ void slow_conv_2D( for (int wh = 0; wh < wH; ++wh) { for (int ww = 0; ww < wW; ++ww) { - int ih = ih_base + wh * wt_dilation[0]; - int iw = iw_base + ww * wt_dilation[1]; + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; @@ -153,25 +165,74 @@ void slow_conv_2D( } // o }; + int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; + int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; + + int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); + int init_w = (flip ? 
(wW - 1) * wt_dilation[1] : 0); + + int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; + int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; + + int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; + int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; + + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); + + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * wt_strides[0] - padding[0] + init_h; + + int wh_base = 0; + while (wh_base < wH && ih_loop % in_dilation[0] != 0) { + wh_base++; + ih_loop += jump_h; + } + + base_h[i] = wh_base; + } + + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * wt_strides[1] - padding[1] + init_w; + + int ww_base = 0; + while (ww_base < wW && iw_loop % in_dilation[1] != 0) { + ww_base++; + iw_loop += jump_w; + } + + base_w[j] = ww_base; + } + auto pt_conv_all_checks = [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding[0]; int iw_base = ow * wt_strides[1] - padding[1]; + int wh_base = base_h[oh % f_out_jump_h]; + int ww_base = base_w[ow % f_out_jump_w]; + for (int o = 0; o < O; ++o) { float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int ih = ih_base + wh * wt_dilation[0]; - int iw = iw_base + ww * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + const T* in_ptr_pt = - in_ptr + ih * in_stride_H + iw * in_stride_W; + in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; for (int c = 0; c < C; ++c) { r += static_cast(in_ptr_pt[0]) * @@ -191,13 +252,17 @@ void slow_conv_2D( }; int oH_border_0 = 0; - int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0]; - int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]; + int oH_border_1 = + is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; + int oH_border_2 = std::max( + oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); int oH_border_3 = oH; int oW_border_0 = 0; - int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1]; - int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]; + int oW_border_1 = + is_idil_one ? 
((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; + int oW_border_2 = std::max( + oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); int oW_border_3 = oW; for (int n = 0; n < N; ++n) { @@ -246,15 +311,18 @@ void dispatch_slow_conv_1D( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { if (in.dtype() == float32) { - return slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation); + return slow_conv_1D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == float16) { return slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == bfloat16) { return slow_conv_1D( - in, wt, out, padding, wt_strides, wt_dilation); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); @@ -267,15 +335,18 @@ void dispatch_slow_conv_2D( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { if (in.dtype() == float32) { - return slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation); + return slow_conv_2D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == float16) { return slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else if (in.dtype() == bfloat16) { return slow_conv_2D( - in, wt, out, padding, wt_strides, wt_dilation); + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } else { throw std::invalid_argument( "[Convolution::eval] got unsupported data type."); @@ -493,13 +564,16 @@ void conv_1D_cpu( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { - if (wt_dilation[0] == 1) { + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { + if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { return explicit_gemm_conv_1D_cpu( in, wt, out, padding, wt_strides, wt_dilation); } - return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation); + return dispatch_slow_conv_1D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } void conv_2D_cpu( @@ -508,8 +582,11 @@ void conv_2D_cpu( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { - return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation); + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { + return dispatch_slow_conv_2D( + in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip); } } // namespace @@ -523,12 +600,26 @@ void Convolution::eval(const std::vector& inputs, array& out) { // 2D convolution if (in.ndim() == (2 + 2)) { return conv_2D_cpu( - in, wt, out, padding_, kernel_strides_, kernel_dilation_); + in, + wt, + out, + padding_, + kernel_strides_, + kernel_dilation_, + input_dilation_, + flip_); } // 1D convolution else if (in.ndim() == (1 + 2)) { return conv_1D_cpu( - in, wt, out, padding_, kernel_strides_, kernel_dilation_); + in, + wt, + out, + padding_, + kernel_strides_, + kernel_dilation_, + input_dilation_, + flip_); } // Throw error else { 
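Reviewer note: the CPU convolution path above now threads `in_dilation` and `flip` through `slow_conv_1D`/`slow_conv_2D` and into `Convolution::eval`, alongside the existing padding, strides, and kernel dilation. Below is a minimal sketch of how this generalized convolution surfaces in Python. It assumes the `mx.conv_general` binding (added to `ops.rst` in this diff) exposes `stride`, `padding`, `kernel_dilation`, `input_dilation`, and `flip` keywords; the exact argument names are an assumption, so treat the snippet as illustrative rather than authoritative.

```python
# Sketch of the generalized convolution exercised by the CPU path above.
# Assumption: mx.conv_general takes stride/padding/kernel_dilation/
# input_dilation/flip keywords mirroring the Convolution primitive's members.
import mlx.core as mx

N, H, W, C, O, kH, kW = 2, 8, 8, 3, 4, 3, 3
x = mx.random.normal((N, H, W, C))      # NHWC input, as in conv_bench.py
w = mx.random.normal((O, kH, kW, C))    # OHWC weights, as in conv_bench.py

# Plain strided conv; with default dilations and flip=False this matches
# mx.conv2d(x, w, stride=2, padding=1).
y = mx.conv_general(x, w, stride=2, padding=1)

# Dilated (atrous) kernel.
y_dil = mx.conv_general(x, w, kernel_dilation=2, padding=2)

# Input dilation plus a flipped kernel, the combination the new
# in_dilation/flip arguments enable in a single forward convolution.
y_tr = mx.conv_general(x, w, input_dilation=2, flip=True, padding=1)

print(y.shape, y_dil.shape, y_tr.shape)
```

Input dilation together with a flipped kernel is what lets the data gradient of a strided convolution be written as one forward convolution, which is why the two parameters are introduced together in this change.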
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index c65028d95..53b7a65f7 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -87,6 +87,7 @@ DEFAULT(Reshape) DEFAULT(Round) DEFAULT(Scan) DEFAULT(Scatter) +DEFAULT(Select) DEFAULT(Sigmoid) DEFAULT(Sign) DEFAULT(Sin) diff --git a/mlx/backend/common/ops.h b/mlx/backend/common/ops.h index 8b2d7ab58..b5b0953b2 100644 --- a/mlx/backend/common/ops.h +++ b/mlx/backend/common/ops.h @@ -7,6 +7,10 @@ namespace mlx::core::detail { +namespace { +constexpr float inf = std::numeric_limits::infinity(); +} // namespace + typedef union { int i; float f; @@ -588,4 +592,11 @@ struct LogicalOr { }; }; +struct Select { + template + T operator()(bool condition, T x, T y) { + return condition ? x : y; + } +}; + } // namespace mlx::core::detail diff --git a/mlx/backend/common/select.cpp b/mlx/backend/common/select.cpp new file mode 100644 index 000000000..1daa771b3 --- /dev/null +++ b/mlx/backend/common/select.cpp @@ -0,0 +1,72 @@ +// Copyright © 2023 Apple Inc. + +#include + +#include "mlx/backend/common/ternary.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void select_op( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + switch (out.dtype()) { + case bool_: + ternary_op(a, b, c, out, op); + break; + case uint8: + ternary_op(a, b, c, out, op); + break; + case uint16: + ternary_op(a, b, c, out, op); + break; + case uint32: + ternary_op(a, b, c, out, op); + break; + case uint64: + ternary_op(a, b, c, out, op); + break; + case int8: + ternary_op(a, b, c, out, op); + break; + case int16: + ternary_op(a, b, c, out, op); + break; + case int32: + ternary_op(a, b, c, out, op); + break; + case int64: + ternary_op(a, b, c, out, op); + break; + case float16: + ternary_op(a, b, c, out, op); + break; + case float32: + ternary_op(a, b, c, out, op); + break; + case bfloat16: + ternary_op(a, b, c, out, op); + break; + case complex64: + ternary_op(a, b, c, out, op); + break; + } +} + +} // namespace + +void Select::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 3); + const auto& condition = inputs[0]; + const auto& a = inputs[1]; + const auto& b = inputs[2]; + select_op(condition, a, b, out, detail::Select()); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h new file mode 100644 index 000000000..52d202df7 --- /dev/null +++ b/mlx/backend/common/ternary.h @@ -0,0 +1,226 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/ops.h" +#include "mlx/backend/common/utils.h" +namespace mlx::core { + +namespace { + +// TODO: Add support for more combinations of input types. 
+enum class TernaryOpType { + ScalarScalarScalar, + General, +}; + +TernaryOpType +get_ternary_op_type(const array& a, const array& b, const array& c) { + TernaryOpType topt; + if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) { + topt = TernaryOpType::ScalarScalarScalar; + } else { + topt = TernaryOpType::General; + } + return topt; +} + +void set_ternary_op_output_data( + const array& a, + const array& b, + const array& c, + array& out, + TernaryOpType topt, + bool donate_with_move = false) { + switch (topt) { + case TernaryOpType::ScalarScalarScalar: + out.set_data( + allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags()); + break; + case TernaryOpType::General: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + break; + } +} + +template +void ternary_op_dims1( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t c_idx = 0; + for (size_t i = 0; i < out.size(); ++i) { + dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); + a_idx += a.strides()[0]; + b_idx += b.strides()[0]; + c_idx += c.strides()[0]; + } +} + +template +void ternary_op_dims2( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t c_idx = 0; + size_t out_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); + a_idx += a.strides()[1]; + b_idx += b.strides()[1]; + c_idx += c.strides()[1]; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; + } +} + +template +void ternary_op_dims3( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t c_idx = 0; + size_t out_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + for (size_t k = 0; k < a.shape()[2]; ++k) { + dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); + a_idx += a.strides()[2]; + b_idx += b.strides()[2]; + c_idx += c.strides()[2]; + } + a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; + b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; + c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2]; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; + } +} + +template +void ternary_op_dims4( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + + U* dst = out.data(); + size_t a_idx = 0; + size_t b_idx = 0; + size_t c_idx = 0; + size_t out_idx = 0; + for (size_t i = 0; i < a.shape()[0]; ++i) { + for (size_t j = 0; j < a.shape()[1]; ++j) { + for (size_t k = 0; k < a.shape()[2]; ++k) { + for (size_t ii = 0; ii < a.shape()[3]; ++ii) { + dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); + a_idx += a.strides()[3]; + 
b_idx += b.strides()[3]; + c_idx += c.strides()[3]; + } + a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3]; + b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3]; + c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3]; + } + a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2]; + b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2]; + c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2]; + } + a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1]; + b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1]; + c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1]; + } +} + +template +void ternary_op_dispatch_dims( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + switch (out.ndim()) { + case 1: + ternary_op_dims1(a, b, c, out, op); + return; + case 2: + ternary_op_dims2(a, b, c, out, op); + return; + case 3: + ternary_op_dims3(a, b, c, out, op); + return; + case 4: + ternary_op_dims4(a, b, c, out, op); + return; + } + + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* dst = out.data(); + for (size_t i = 0; i < out.size(); i++) { + int a_idx = elem_to_loc(i, a.shape(), a.strides()); + int b_idx = elem_to_loc(i, b.shape(), b.strides()); + int c_idx = elem_to_loc(i, c.shape(), c.strides()); + dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]); + } +} + +template +void ternary_op( + const array& a, + const array& b, + const array& c, + array& out, + Op op) { + TernaryOpType topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt); + + // The full computation is scalar-scalar-scalar so we call the base op once. + if (topt == TernaryOpType::ScalarScalarScalar) { + *(out.data()) = op(*a.data(), *b.data(), *c.data()); + return; + } + + ternary_op_dispatch_dims(a, b, c, out, op); +} + +} // namespace + +} // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 063c283fe..fe764c494 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -4,7 +4,7 @@ add_custom_command( ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${CMAKE_C_COMPILER} - ${CMAKE_SOURCE_DIR} + ${PROJECT_SOURCE_DIR} DEPENDS make_compiled_preamble.sh kernels/compiled_preamble.h kernels/unary.h diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index cab27b715..d8e4538ae 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" @@ -23,16 +23,6 @@ void* Buffer::raw_ptr() { namespace metal { -static bool cache_enabled_ = true; - -bool cache_enabled() { - return cache_enabled_; -} - -void set_cache_enabled(bool enabled) { - cache_enabled_ = enabled; -} - namespace { BufferCache::BufferCache(MTL::Device* device) @@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), buffer_cache_(device_), - peak_allocated_size_(0), block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()), - gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {} + gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()), + max_pool_size_(block_limit_) {} + +size_t MetalAllocator::set_cache_limit(size_t limit) { + std::swap(limit, max_pool_size_); + return limit; +}; + +size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) { + std::swap(limit, block_limit_); + relaxed_ = relaxed; + gc_limit_ = std::min( + block_limit_, + static_cast(0.95 * device_->recommendedMaxWorkingSetSize())); + return limit; +}; Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Metal doesn't like empty buffers @@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Try the cache MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); - + size_t pool_size = get_cache_memory(); if (!buf) { + size_t mem_required = get_active_memory() + pool_size + size; + // If there is too much memory pressure, fail (likely causes a wait). - if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) { + if (!(allow_swap && relaxed_) && mem_required >= block_limit_) { return Buffer{nullptr}; } @@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // If we have a lot of memory pressure, check if we can reclaim some memory // from the cache - if (device_->currentAllocatedSize() + size >= gc_limit_) { - size_t min_bytes_to_free = - size + device_->currentAllocatedSize() - gc_limit_; - buffer_cache_.release_cached_buffers(min_bytes_to_free); + if (mem_required >= gc_limit_) { + buffer_cache_.release_cached_buffers(mem_required - gc_limit_); } // Allocate new buffer if needed @@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { buf = device_->newBuffer(size, res_opt); } - peak_allocated_size_ = - std::max(peak_allocated_size_, device_->currentAllocatedSize()); + // Maintain the cache below the requested limit + if (pool_size >= max_pool_size_) { + auto thread_pool = metal::new_scoped_memory_pool(); + buffer_cache_.release_cached_buffers(pool_size - max_pool_size_); + } + + active_memory_ += buf->length(); + peak_memory_ = std::max(peak_memory_, active_memory_); return Buffer{static_cast(buf)}; } void MetalAllocator::free(Buffer buffer) { auto buf = static_cast(buffer.ptr()); - if (cache_enabled()) { + active_memory_ -= buf->length(); + if (max_pool_size_ > 0) { buffer_cache_.recycle_to_cache(buf); } else { buf->release(); @@ -218,6 +229,22 @@ MetalAllocator& allocator() { return allocator_; } +size_t set_cache_limit(size_t limit) { + return allocator().set_cache_limit(limit); +} +size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { + return allocator().set_memory_limit(limit, relaxed); +} +size_t get_active_memory() { + return allocator().get_active_memory(); +} +size_t 
get_peak_memory() { + return allocator().get_peak_memory(); +} +size_t get_cache_memory() { + return allocator().get_cache_memory(); +} + } // namespace metal } // namespace mlx::core diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 45a58bc13..a31cb5fb4 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -24,6 +24,9 @@ class BufferCache { MTL::Buffer* reuse_from_cache(size_t size); void recycle_to_cache(MTL::Buffer* buf); void release_cached_buffers(size_t min_bytes_to_free); + size_t pool_size() { + return pool_size_; + } private: struct BufferHolder { @@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator { public: virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; + size_t get_active_memory() { + return active_memory_; + }; + size_t get_peak_memory() { + return peak_memory_; + }; + size_t get_cache_memory() { + return buffer_cache_.pool_size(); + }; + size_t set_cache_limit(size_t limit); + size_t set_memory_limit(size_t limit, bool relaxed); private: MTL::Device* device_; @@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator { BufferCache buffer_cache_; // Allocation stats - size_t peak_allocated_size_; size_t block_limit_; size_t gc_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + size_t max_pool_size_; + bool relaxed_{true}; }; MetalAllocator& allocator(); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 4ade8da17..426f6aefe 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#include #include @@ -7,80 +7,72 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/kernels/conv_params.h" #include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" +using namespace mlx::steel; + namespace mlx::core { namespace { -void explicit_gemm_conv_1D_gpu( +template +void explicit_gemm_conv_ND_gpu( const Stream& s, metal::Device& d, const array& in, const array& wt, array out, - const MLXConvParams<1>& conv_params) { - // Pad input - std::vector padded_shape = { - conv_params.N, conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.C}; - array in_padded(padded_shape, in.dtype(), nullptr, {}); + const MLXConvParams& conv_params) { + // Prepare unfolding array + std::vector unfolded_shape = { + static_cast(out.size() / conv_params.O), + static_cast(wt.size() / conv_params.O)}; + array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); - // Fill with zeros - copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s); + in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); - // Pick input slice from padded - size_t data_offset = conv_params.pad[0] * in_padded.strides()[1]; - array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); - in_padded_slice.copy_shared_buffer( - in_padded, - in_padded.strides(), - in_padded.flags(), - in_padded_slice.size(), - data_offset); + // Prepare unfolding kernel + std::ostringstream kname; + kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); - // Copy input values into the slice - copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, in_unfolded, 1); - // Make strided view - std::vector strided_shape = { - conv_params.N, conv_params.oS[0], conv_params.wS[0], conv_params.C}; + compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2); - std::vector strided_strides = { - in_padded.strides()[0], - in_padded.strides()[1] * conv_params.str[0], - in_padded.strides()[1], - in_padded.strides()[2]}; - auto flags = in_padded.flags(); + // Launch unfolding kernel + int tgp_x = std::min(conv_params.C, 64); + tgp_x = 32 * ((tgp_x + 32 - 1) / 32); + int tgp_y = 256 / tgp_x; - array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); - in_strided_view.copy_shared_buffer( - in_padded, strided_strides, flags, in_strided_view.size(), 0); + MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1); + MTL::Size grid_dims = MTL::Size( + conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); - // Materialize strided view - std::vector strided_reshape = { - conv_params.N * conv_params.oS[0], conv_params.wS[0] * conv_params.C}; - array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy_gpu(in_strided_view, in_strided, CopyType::General, s); + compute_encoder->dispatchThreads(grid_dims, group_dims); // Perform gemm - std::vector copies = {in_padded, in_strided}; + std::vector copies; return steel_matmul( s, d, - /*a = */ in_strided, + /*a = */ in_unfolded, /*b = */ wt, /*c = */ out, - /*M = */ strided_reshape[0], + /*M = */ unfolded_shape[0], /*N = */ conv_params.O, - /*K = */ strided_reshape[1], + /*K = */ 
unfolded_shape[1], /*batch_size_out = */ 1, - /*a_cols = */ strided_reshape[1], - /*b_cols = */ strided_reshape[1], + /*a_cols = */ unfolded_shape[1], + /*b_cols = */ unfolded_shape[1], /*a_transposed = */ false, /*b_transposed = */ true, /*copies = */ copies); @@ -94,7 +86,9 @@ void conv_1D_gpu( array out, const std::vector& padding, const std::vector& wt_strides, - const std::vector& wt_dilation) { + const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip) { // Make conv params MLXConvParams<1> conv_params{ /* const int N = */ in.shape(0), @@ -105,24 +99,19 @@ void conv_1D_gpu( /* const int oS[NDIM] = */ {out.shape(1)}, /* const int str[NDIM] = */ {wt_strides[0]}, /* const int pad[NDIM] = */ {padding[0]}, - /* const int dil[NDIM] = */ {wt_dilation[0]}, + /* const int kdil[NDIM] = */ {wt_dilation[0]}, + /* const int idil[NDIM] = */ {in_dilation[0]}, /* const size_t in_strides[NDIM + 2] = */ {in.strides()[0], in.strides()[1], in.strides()[2]}, /* const size_t wt_strides[NDIM + 2] = */ {wt.strides()[0], wt.strides()[1], wt.strides()[2]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides()[0], out.strides()[1], out.strides()[2]}, - }; + /* const int groups = */ 1, + /* const bool flip = */ flip}; // Direct to explicit gemm conv - if (wt_dilation[0] == 1) { - explicit_gemm_conv_1D_gpu(s, d, in, wt, out, conv_params); - } - - // Direct to fallback conv - else { - throw std::invalid_argument("[conv_1D_gpu] Dilation needs to be 1."); - } + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); } void slow_conv_2D_gpu( @@ -168,113 +157,262 @@ void implicit_gemm_conv_2D_gpu( const array& wt, array out, const MLXConvParams<2>& conv_params) { - int bm = 32, bn = 32, bk = 16; + // Deduce implicit gemm size + int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; + int implicit_N = conv_params.O; + int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C; + + // Determine block and warp tiles int wm = 2, wn = 2; + int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32; + int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32; + int bk = 16; + + if (implicit_N <= 16) { + bn = 8; + wm = 4; + wn = 1; + } + + int tn = (implicit_N + bn - 1) / bn; + int tm = (implicit_M + bm - 1) / bm; + int swizzle_log = 0; + + // Fix small channel specialization + int n_channel_specialization = 0; + int channel_k_iters = ((conv_params.C + bk - 1) / bk); + int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters; + + if (conv_params.C <= 2) { + gemm_k_iters = (implicit_K + bk - 1) / bk; + n_channel_specialization = conv_params.C; + } else if (conv_params.C <= 4) { + gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk; + n_channel_specialization = conv_params.C; + } + + bool small_filter = (!n_channel_specialization) && + (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16); + + // Fix host side helper params + int sign = (conv_params.flip ? 
-1 : 1); + int ijw = conv_params.in_strides[2] * conv_params.kdil[1]; + int ijh = conv_params.in_strides[1] * conv_params.kdil[0]; + + int inp_jump_w = sign * ijw; + int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw); + int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh - + sign * (conv_params.wS[1] - 1) * ijw; + + // Build implicit gemm params + ImplicitGemmConv2DParams gemm_params{ + /* const int M = */ implicit_M, + /* const int N = */ implicit_N, + /* const int K = */ implicit_K, + + /* const int gemm_k_iterations = */ gemm_k_iters, + + /* const int inp_jump_w = */ inp_jump_w, + /* const int inp_jump_h = */ inp_jump_h, + /* const int inp_jump_c = */ inp_jump_c, + + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int swizzle_log = */ swizzle_log}; + + // Determine kernel std::ostringstream kname; kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" - << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; + << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_" + << (n_channel_specialization ? std::to_string(n_channel_specialization) + : "l") + << "_filter_" << (small_filter ? 's' : 'l'); // Encode and dispatch kernel auto compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); - int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; - int implicit_N = conv_params.O; - - size_t grid_dim_x = (implicit_N + bn - 1) / bn; - size_t grid_dim_y = (implicit_M + bm - 1) / bm; + // Deduce grid launch dimensions + int tile = 1 << swizzle_log; + size_t grid_dim_y = (tm + tile - 1) / tile; + size_t grid_dim_x = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1); + // Encode arrays set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, wt, 1); set_array_buffer(compute_encoder, out, 2); + // Encode params compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); + compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4); + + // Launch kernel compute_encoder->dispatchThreadgroups(grid_dims, group_dims); } -void explicit_gemm_conv_2D_gpu( +void implicit_gemm_conv_2D_general_gpu( const Stream& s, metal::Device& d, const array& in, const array& wt, array out, const MLXConvParams<2>& conv_params) { - // Pad input - std::vector padded_shape = { - conv_params.N, - conv_params.iS[0] + 2 * conv_params.pad[0], - conv_params.iS[1] + 2 * conv_params.pad[1], - conv_params.C}; - array in_padded(padded_shape, in.dtype(), nullptr, {}); + // Deduce implicit gemm size + int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; + int implicit_N = conv_params.O; + int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C; - // Fill with zeros - copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s); + // Determine block and warp tiles + int wm = 2, wn = 2; - // Pick input slice from padded - size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] + - conv_params.pad[1] * in_padded.strides()[2]; - array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); - in_padded_slice.copy_shared_buffer( - in_padded, - in_padded.strides(), - in_padded.flags(), - in_padded_slice.size(), - data_offset); + // Make jump params + int f_wgt_jump_h = + std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0]; + int f_wgt_jump_w = + std::lcm(conv_params.idil[1], conv_params.kdil[1]) 
/ conv_params.kdil[1]; - // Copy input values into the slice - copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s); + int f_out_jump_h = + std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0]; + int f_out_jump_w = + std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1]; - // Make strided view - std::vector strided_shape = { - conv_params.N, - conv_params.oS[0], - conv_params.oS[1], - conv_params.wS[0], - conv_params.wS[1], - conv_params.C}; + int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h; + int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w; + int adj_out_hw = adj_out_h * adj_out_w; + int adj_implicit_m = conv_params.N * adj_out_hw; - std::vector strided_strides = { - in_padded.strides()[0], - in_padded.strides()[1] * conv_params.str[0], - in_padded.strides()[2] * conv_params.str[1], - in_padded.strides()[1], - in_padded.strides()[2], - in_padded.strides()[3]}; - auto flags = in_padded.flags(); + Conv2DGeneralJumpParams jump_params{ + /* const int f_wgt_jump_h = */ f_wgt_jump_h, + /* const int f_wgt_jump_w = */ f_wgt_jump_w, - array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); - in_strided_view.copy_shared_buffer( - in_padded, strided_strides, flags, in_strided_view.size(), 0); + /* const int f_out_jump_h = */ f_out_jump_h, + /* const int f_out_jump_w = */ f_out_jump_w, - // Materialize strided view - std::vector strided_reshape = { - conv_params.N * conv_params.oS[0] * conv_params.oS[1], - conv_params.wS[0] * conv_params.wS[1] * conv_params.C}; - array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); - copy_gpu(in_strided_view, in_strided, CopyType::General, s); + /* const int adj_out_h = */ adj_out_h, + /* const int adj_out_w = */ adj_out_w, + /* const int adj_out_hw = */ adj_out_hw, + /* const int adj_implicit_m = */ adj_implicit_m}; - // Perform gemm - std::vector copies = {in_padded, in_strided}; - return steel_matmul( - s, - d, - /*a = */ in_strided, - /*b = */ wt, - /*c = */ out, - /*M = */ strided_reshape[0], - /*N = */ conv_params.O, - /*K = */ strided_reshape[1], - /*batch_size_out = */ 1, - /*a_cols = */ strided_reshape[1], - /*b_cols = */ strided_reshape[1], - /*a_transposed = */ false, - /*b_transposed = */ true, - /*copies = */ copies); + // Make base info + std::vector base_h(f_out_jump_h); + std::vector base_w(f_out_jump_w); + + int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0]; + int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1]; + + int init_h = + (conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0); + int init_w = + (conv_params.flip ? 
(conv_params.wS[1] - 1) * conv_params.kdil[1] : 0); + + for (int i = 0; i < f_out_jump_h; ++i) { + int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h; + + int wh_base = 0; + while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) { + wh_base++; + ih_loop += jump_h; + } + + int wh_size = + ((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h; + base_h[i] = {wh_base, wh_size}; + } + + for (int j = 0; j < f_out_jump_w; ++j) { + int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w; + + int ww_base = 0; + while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) { + ww_base++; + iw_loop += jump_w; + } + + int ww_size = + ((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w; + base_w[j] = {ww_base, ww_size}; + } + + // Collect block sizes + int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32; + int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32; + int bk = 16; + + int tn = (implicit_N + bn - 1) / bn; + int tm = (adj_implicit_m + bm - 1) / bm; + int swizzle_log = 0; + + // Get channel iteration info + int channel_k_iters = ((conv_params.C + bk - 1) / bk); + int gemm_k_iters = channel_k_iters; + + // Fix host side helper params + int sign = (conv_params.flip ? -1 : 1); + int ijw = conv_params.in_strides[2] * conv_params.kdil[1]; + int ijh = conv_params.in_strides[1] * conv_params.kdil[0]; + + int inp_jump_w = sign * ijw; + int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw); + int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh - + sign * (conv_params.wS[1] - 1) * ijw; + + // Build implicit gemm params + ImplicitGemmConv2DParams gemm_params{ + /* const int M = */ implicit_M, + /* const int N = */ implicit_N, + /* const int K = */ implicit_K, + + /* const int gemm_k_iterations = */ gemm_k_iters, + + /* const int inp_jump_w = */ inp_jump_w, + /* const int inp_jump_h = */ inp_jump_h, + /* const int inp_jump_c = */ inp_jump_c, + + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int swizzle_log = */ swizzle_log}; + + // Determine kernel + std::ostringstream kname; + kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm + << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + // Deduce grid launch dimensions + int tile = 1 << swizzle_log; + size_t grid_dim_y = (tm + tile - 1) / tile; + size_t grid_dim_x = tn * tile; + size_t grid_dim_z = f_out_jump_h * f_out_jump_w; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z); + + // Encode arrays + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, wt, 1); + set_array_buffer(compute_encoder, out, 2); + + // Encode params + compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); + compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4); + compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5); + + compute_encoder->setBytes( + base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6); + compute_encoder->setBytes( + base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7); + + // Launch kernel + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); } void winograd_conv_2D_gpu( @@ -299,6 +437,7 @@ void winograd_conv_2D_gpu( // Fill with 
zeros array zero_arr = array(0, in.dtype()); copy_gpu(zero_arr, in_padded, CopyType::Scalar, s); + copies_w.push_back(zero_arr); // Pick input slice from padded size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] + @@ -327,7 +466,8 @@ void winograd_conv_2D_gpu( /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)}, /* const int str[NDIM] = */ {1, 1}, /* const int pad[NDIM] = */ {0, 0}, - /* const int dil[NDIM] = */ {1, 1}, + /* const int kdil[NDIM] = */ {1, 1}, + /* const int idil[NDIM] = */ {1, 1}, /* const size_t in_strides[NDIM + 2] = */ {in_padded.strides()[0], in_padded.strides()[1], @@ -337,6 +477,8 @@ void winograd_conv_2D_gpu( {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, + /* const int groups = */ 1, + /* const bool flip = */ false, }; int O_c = conv_params.O; @@ -460,6 +602,8 @@ void conv_2D_gpu( const std::vector& padding, const std::vector& wt_strides, const std::vector& wt_dilation, + const std::vector& in_dilation, + bool flip, std::vector& copies) { // Make conv params MLXConvParams<2> conv_params{ @@ -471,37 +615,47 @@ void conv_2D_gpu( /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)}, /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]}, /* const int pad[NDIM] = */ {padding[0], padding[1]}, - /* const int dil[NDIM] = */ {wt_dilation[0], wt_dilation[1]}, + /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]}, + /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]}, /* const size_t in_strides[NDIM + 2] = */ {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]}, /* const size_t wt_strides[NDIM + 2] = */ {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, /* const size_t out_strides[NDIM + 2] = */ {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, + /* const int groups = */ 1, + /* const bool flip = */ flip, }; + bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1; + bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; + bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; + + bool inp_large = (conv_params.in_strides[0] >= 1ul << 18); + bool channels_large = (conv_params.C + conv_params.O) >= 512; + bool channels_med = (conv_params.C + conv_params.O) >= 256; + // Direct to winograd conv - if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && - conv_params.C >= 64 && conv_params.O >= 64 && conv_params.wS[0] == 3 && - conv_params.wS[1] == 3 && conv_params.str[0] == 1 && - conv_params.str[1] == 1 && conv_params.dil[0] == 1 && - conv_params.dil[1] == 1) { - winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); + if (!flip && is_stride_one && is_kdil_one && is_idil_one && + conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && + (channels_large || (channels_med && inp_large))) { + return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); } // Direct to implicit gemm conv - else if (conv_params.C % 32 == 0 && conv_params.O % 32 == 0) { - implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) && + (conv_params.O <= 16 || conv_params.O % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) { + return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, 
out, conv_params); } // Direct to explicit gemm conv - else if (wt_dilation[0] == 1 && wt_dilation[1] == 1) { - explicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); - } - - // Direct to fallback conv else { - slow_conv_2D_gpu(s, d, in, wt, out, conv_params); + return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params); } } @@ -532,11 +686,31 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { // 2D conv if (out.ndim() == 4) { conv_2D_gpu( - s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_, copies); + s, + d, + in, + wt, + out, + padding_, + kernel_strides_, + kernel_dilation_, + input_dilation_, + flip_, + copies); } // 1D conv else if (out.ndim() == 3) { - conv_1D_gpu(s, d, in, wt, out, padding_, kernel_strides_, kernel_dilation_); + conv_1D_gpu( + s, + d, + in, + wt, + out, + padding_, + kernel_strides_, + kernel_dilation_, + input_dilation_, + flip_); } // Throw error else { diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 4eeb8858e..56a312a5d 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { // Get kernel name std::ostringstream kname; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; - kname << "scatter" << type_to_name(out) << idx_type_name; + + int idx_ndim = nidx ? inputs[1].ndim() : 0; + bool index_nd1_specialization = (idx_ndim == 1); + + // Bail from fast path (1d index specialization) if scatter dims aren't + // the outermost dims and contiguous since update access won't be raster + // order. + for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) { + index_nd1_specialization &= (axes_[i] == i); + } + + // Bail from fast path (1d index specialization) if any of the dims are + // broadcasted, since we can't rely on linear indexing in that case. + for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) { + index_nd1_specialization &= inputs[i].flags().row_contiguous; + } + + if (index_nd1_specialization) { + kname << "scatter_1d_index" << type_to_name(out) << idx_type_name; + } else { + kname << "scatter" << type_to_name(out) << idx_type_name; + } switch (reduce_type_) { case Scatter::None: kname << "_none"; @@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); - // Collect all idx shapes and strides into one place - int idx_ndim = nidx ? 
inputs[1].ndim() : 0; - std::vector idx_shapes; - std::vector idx_strides; - - for (int i = 0; i < nidx; ++i) { - idx_shapes.insert( - idx_shapes.end(), - inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end()); - - idx_strides.insert( - idx_strides.end(), - inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end()); - } - // Set all the buffers set_array_buffer(compute_encoder, upd, 1); set_array_buffer(compute_encoder, out, 2); // Set update info - size_t upd_ndim = upd.ndim(); + uint upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - if (upd_ndim == 0) { - // Need placeholders so Metal doesn't compalain - int shape_ = 0; - size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 3); - compute_encoder->setBytes(&stride_, sizeof(size_t), 4); - } else { - compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3); + + if (index_nd1_specialization) { + bool upd_col_contiguous = upd.flags().col_contiguous; compute_encoder->setBytes( - upd.strides().data(), upd_ndim * sizeof(size_t), 4); - } - compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); - compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); - - // Set output info - size_t out_ndim = out.ndim(); - if (out_ndim == 0) { - // Need placeholders so Metal doesn't compalain - int shape_ = 0; - size_t stride_ = 0; - compute_encoder->setBytes(&shape_, sizeof(int), 7); - compute_encoder->setBytes(&stride_, sizeof(size_t), 8); - } else { - compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7); + out.shape().data(), out.shape().size() * sizeof(int), 3); compute_encoder->setBytes( - out.strides().data(), out_ndim * sizeof(size_t), 8); - } - compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); - compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + out.strides().data(), out.strides().size() * sizeof(size_t), 4); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 5); + compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6); - // Set index info - if (idx_ndim == 0) { - // Add a 0 in idx_shapes and strides to avoid the missing buffer binding - // error in the metal API. 
- idx_shapes.push_back(0); - idx_strides.push_back(0); - } - compute_encoder->setBytes( - idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); - compute_encoder->setBytes( - idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); - compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } - // Set index buffers - for (int i = 1; i < nidx + 1; ++i) { - set_array_buffer(compute_encoder, inputs[i], 20 + i); - } + // Launch grid + MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); + MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); - // Launch grid - MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); - MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + // Collect all idx shapes and strides into one place + std::vector idx_shapes; + std::vector idx_strides; + + for (int i = 0; i < nidx; ++i) { + idx_shapes.insert( + idx_shapes.end(), + inputs[i + 1].shape().begin(), + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), + inputs[i + 1].strides().begin(), + inputs[i + 1].strides().end()); + } + + if (upd_ndim == 0) { + // Need placeholders so Metal doesn't compalain + int shape_ = 0; + size_t stride_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 3); + compute_encoder->setBytes(&stride_, sizeof(size_t), 4); + } else { + compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3); + compute_encoder->setBytes( + upd.strides().data(), upd_ndim * sizeof(size_t), 4); + } + compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + + // Set output info + size_t out_ndim = out.ndim(); + if (out_ndim == 0) { + // Need placeholders so Metal doesn't compalain + int shape_ = 0; + size_t stride_ = 0; + compute_encoder->setBytes(&shape_, sizeof(int), 7); + compute_encoder->setBytes(&stride_, sizeof(size_t), 8); + } else { + compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7); + compute_encoder->setBytes( + out.strides().data(), out_ndim * sizeof(size_t), 8); + } + compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); + compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + + // Set index info + if (idx_ndim == 0) { + // Add a 0 in idx_shapes and strides to avoid the missing buffer binding + // error in the metal API. 
+ idx_shapes.push_back(0); + idx_strides.push_back(0); + } + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid + MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); + MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index afd2fbc8a..b3721b6d4 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -3,11 +3,13 @@ set( ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h + ${CMAKE_CURRENT_SOURCE_DIR}/binary.h ${CMAKE_CURRENT_SOURCE_DIR}/complex.h ${CMAKE_CURRENT_SOURCE_DIR}/defines.h ${CMAKE_CURRENT_SOURCE_DIR}/erf.h ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h + ${CMAKE_CURRENT_SOURCE_DIR}/unary.h ${CMAKE_CURRENT_SOURCE_DIR}/utils.h ) @@ -27,6 +29,7 @@ set( "scan" "softmax" "sort" + "ternary" "unary" "gather" "scatter" @@ -48,11 +51,7 @@ endfunction(build_kernel_base) function(build_kernel KERNEL) set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) - set(HEADERS_PADDED ${HEADERS}) - if(${KERNEL} STREQUAL "conv") - set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h) - endif() - build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}") + build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}") endfunction(build_kernel) foreach(KERNEL ${KERNELS}) diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index f153b920d..f24a32ce8 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -11,8 +11,6 @@ template struct IndexValPair { uint32_t index; U val; - - IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {} }; template @@ -65,10 +63,10 @@ struct ArgMax { template IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { - return IndexValPair( + return IndexValPair{ simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta) - ); + }; } @@ -82,7 +80,6 @@ template const device size_t& ndim [[buffer(5)]], const device size_t& axis_stride [[buffer(6)]], const device size_t& axis_size [[buffer(7)]], - threadgroup IndexValPair *local_data [[threadgroup(0)]], uint gid [[thread_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], @@ -111,7 +108,9 @@ template auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); - IndexValPair best(0, Op::init); + IndexValPair best{0, Op::init}; + + threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) { @@ -172,7 +171,6 @@ template const device size_t& ndim [[buffer(5)]], \ const device size_t& axis_stride [[buffer(6)]], \ const device size_t& axis_size [[buffer(7)]], \ - threadgroup IndexValPair *local_data [[threadgroup(0)]], \ uint gid [[thread_position_in_grid]], \ uint lid 
[[thread_position_in_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \ diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 4d449ab69..eff687231 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -2,16 +2,6 @@ #include "mlx/backend/metal/kernels/binary.h" -template -[[kernel]] void binary_op_s2s( - device const T* a, - device const T* b, - device U* c, - uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[0]); -} - - template [[kernel]] void binary_op_ss( device const T* a, diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h index d5bf33696..12fdc8117 100644 --- a/mlx/backend/metal/kernels/compiled_preamble.h +++ b/mlx/backend/metal/kernels/compiled_preamble.h @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/kernels/binary.h" +#include "mlx/backend/metal/kernels/ternary.h" #include "mlx/backend/metal/kernels/unary.h" typedef half float16_t; diff --git a/mlx/backend/metal/kernels/conv.h b/mlx/backend/metal/kernels/conv.h deleted file mode 100644 index 1db3ebac8..000000000 --- a/mlx/backend/metal/kernels/conv.h +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#include -#include -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/conv_params.h" - -#define MLX_MTL_CONST static constant constexpr const - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int vec_size, - int tgp_size, - int tgp_padding = 0> -struct Conv2DInputBlockLoader { - // Destination dimensions - MLX_MTL_CONST int dst_fd = BM; - MLX_MTL_CONST int dst_ld = BK + tgp_padding; - MLX_MTL_CONST int n_vecs = BK / vec_size; - - // Stride along block row within the block - MLX_MTL_CONST int bstride = tgp_size / n_vecs; - MLX_MTL_CONST int n_rows = dst_fd / bstride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - const constant MLXConvParams<2>& params; - - int weight_h; - int weight_w; - - int offsets_n[n_rows]; - int offsets_oh[n_rows]; - int offsets_ow[n_rows]; - - /* Constructor */ - METAL_FUNC Conv2DInputBlockLoader( - const device T* src_, - threadgroup T* dst_, - const constant MLXConvParams<2>& params_, - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / n_vecs), - bj(vec_size * (thread_idx % n_vecs)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bj), - params(params_), - weight_h(0), - weight_w(0) { - int out_n_pixels = params.oS[0] * params.oS[1]; - - for (int i = 0; i < n_rows; ++i) { - int offset_nhw = tid.y * BM + bi + i * bstride; - offsets_n[i] = offset_nhw / out_n_pixels; - int hw = offset_nhw % out_n_pixels; - offsets_oh[i] = hw / params.oS[1]; - offsets_ow[i] = hw % params.oS[1]; - } - - (void)lid; - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { -#pragma clang loop unroll(full) - for (short i = 0, is = 0; 
i < n_rows; ++i, is += bstride) { - int n = offsets_n[i]; - int oh = offsets_oh[i]; - int ow = offsets_ow[i]; - - int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0]; - int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1]; - - // Read from input if in bounds - if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) { - const device T* curr_src = src + n * params.in_strides[0] + - ih * params.in_strides[1] + iw * params.in_strides[2]; - -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = curr_src[j]; - } - } - - // Zero pad otherwise - else { -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_w < params.wS[1]) { - return; - } - - weight_w = 0; - - if (++weight_h < params.wS[0]) { - return; - } - - weight_h = 0; - - src += BK; - } -}; - -template < - typename T, - int BM, - int BN, - int BK, - int vec_size, - int tgp_size, - int tgp_padding = 0> -struct Conv2DWeightBlockLoader { - // Destination dimensions - MLX_MTL_CONST int dst_fd = BN; - MLX_MTL_CONST int dst_ld = BK + tgp_padding; - MLX_MTL_CONST int n_vecs = BK / vec_size; - - // Stride along block row within the block - MLX_MTL_CONST int bstride = tgp_size / n_vecs; - MLX_MTL_CONST int n_rows = dst_fd / bstride; - - // Leading dimension for src - const int src_ld; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - const constant MLXConvParams<2>& params; - - int weight_h; - int weight_w; - - /* Constructor */ - METAL_FUNC Conv2DWeightBlockLoader( - const device T* src_, - threadgroup T* dst_, - const constant MLXConvParams<2>& params_, - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_.wt_strides[0]), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / n_vecs), - bj(vec_size * (thread_idx % n_vecs)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj), - params(params_), - weight_h(0), - weight_w(0) { - (void)lid; - (void)tid; - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - const device T* curr_src = - src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2]; -#pragma clang loop unroll(full) - for (short i = 0; i < dst_fd; i += bstride) { -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_w < params.wS[1]) { - return; - } - - weight_w = 0; - - if (++weight_h < params.wS[0]) { - return; - } - - weight_h = 0; - - src += BK; - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Transforms -/////////////////////////////////////////////////////////////////////////////// - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; - -/////////////////////////////////////////////////////////////////////////////// -// MMA helper 
-/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - int tgp_padding_a = 0, - int tgp_padding_b = 0, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct Conv2DBlockMMA { - // Warp tile size along M - MLX_MTL_CONST int TM = BM / (WM * 8); - // Warp tile size along N - MLX_MTL_CONST int TN = BN / (WN * 8); - - // Warp tile simdgroup matrix strides along M - MLX_MTL_CONST int TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - MLX_MTL_CONST int TN_stride = 8 * WN; - - // Leading dimensions of threadgroup A, B blocks - MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a; - MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b; - - // Strides of A, B along reduction axis - MLX_MTL_CONST short simd_stride_a = - transpose_a ? TM_stride : TM_stride * lda_tgp; - MLX_MTL_CONST short simd_stride_b = - transpose_b ? TN_stride * ldb_tgp : TN_stride; - - // Jump between elements - MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1; - MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1; - - // Offsets within threadgroup - const int tm; - const int tn; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - short sm; - short sn; - - /* Constructor */ - METAL_FUNC Conv2DBlockMMA( - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - short qid = simd_lane_id / 4; - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { -// Iterate over BK in blocks of 8 -#pragma clang loop unroll(full) - for (short kk = 0; kk < BK; kk += 8) { - short2 offset_a = - transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm); - short2 offset_b = - transpose_b ? 
short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm); - - const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x; - const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x; - - simdgroup_barrier(mem_flags::mem_none); -// Load elements from threadgroup A as simdgroup matrices -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = static_cast(As__[0]); - Asimd[i].thread_elements()[1] = static_cast(As__[jump_a]); - As__ += simd_stride_a; - } - - simdgroup_barrier(mem_flags::mem_none); -// Load elements from threadgroup B as simdgroup matrices -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = static_cast(Bs__[0]); - Bsimd[j].thread_elements()[1] = static_cast(Bs__[jump_b]); - Bs__ += simd_stride_b; - } - - simdgroup_barrier(mem_flags::mem_none); -// Multiply and accumulate into result simdgroup matrices -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - simdgroup_multiply_accumulate( - results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]); - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device T* C, const int ldc) const { -#pragma clang loop unroll(full) - for (int i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (int j = 0; j < TN; j++) { - C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] = - Epilogue::apply(results[i * TN + j].thread_elements()[0]); - C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] = - Epilogue::apply(results[i * TN + j].thread_elements()[1]); - } - } - } - - METAL_FUNC void - store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const { -#pragma clang loop unroll(full) - for (int i = 0; i < TM; i++) { - if (tm + i * TM_stride + sm < dst_tile_dims.y) { -#pragma clang loop unroll(full) - for (int j = 0; j < TN; j++) { - if (tn + j * TN_stride + sn < dst_tile_dims.x) { - C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] = - Epilogue::apply(results[i * TN + j].thread_elements()[0]); - } - - if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) { - C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] = - Epilogue::apply(results[i * TN + j].thread_elements()[1]); - } - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct Conv2DImplicitGEMMKernel { - MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T); - MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T); - MLX_MTL_CONST short tgp_mem_size_a = - transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); - MLX_MTL_CONST short tgp_mem_size_b = - transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); - MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; - - MLX_MTL_CONST short tgp_size = WM * WN * 32; - MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 
8 : 4; - - using loader_a_t = - Conv2DInputBlockLoader; - using loader_b_t = - Conv2DWeightBlockLoader; - using mma_t = Conv2DBlockMMA< - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - tgp_padding_a, - tgp_padding_b, - AccumType, - Epilogue>; - - /* Main kernel function */ - static METAL_FUNC void run( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant MLXConvParams<2>& params [[buffer(3)]], - threadgroup T* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - const int c_row = tid.y * BM; - const int c_col = tid.x * BN; - const int K = params.wt_strides[0]; - const int N = params.O; - - B += c_col * K; - C += c_row * N + c_col; - - // Prepare threadgroup memory for loading - threadgroup T* As = tgp_memory; - threadgroup T* Bs = tgp_memory + tgp_mem_size_a; - - // Prepare threadgroup loading operations - loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid); - loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_gid, simd_lid); - - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - mma_op.store_result(C, N); - } -}; \ No newline at end of file diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 77c72c48c..b977876ff 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -1,16 +1,102 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include +#include +#include +#include -#include "mlx/backend/metal/kernels/conv_params.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/conv.h" +#define MLX_MTL_CONST static constant constexpr const using namespace metal; /////////////////////////////////////////////////////////////////////////////// -/// Slow and naive kernels +/// Naive unfold with dilation +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel]] void naive_unfold_Nd( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + const constant MLXConvParams* params [[buffer(2)]], + uint3 gid [[thread_position_in_grid]]) { + + int filter_size = params->C; + for(short i = 0; i < N; i++) filter_size *= params->wS[i]; + + int out_pixels = 1; + for(short i = 0; i < N; i++) out_pixels *= params->oS[i]; + + // Set out + out += gid.z * filter_size + gid.y * (params->C); + + // Corrdinates in input + int is[N] = {0}; + + // gid.z: N oS (Batch and row in unfolded output) + // gid.y: wS (Filter location to unfold input) + // gid.x: C (channel) + + int n = (gid.z) / out_pixels; + int oS = (gid.z) % out_pixels; + int wS = gid.y; + + bool valid = n < params->N; + + // Unroll dimensions + for (int i = N - 1; i >= 0; --i) { + int os_ = (oS % params->oS[i]); + int ws_ = (wS % params->wS[i]); + + ws_ = params->flip ? 
params->wS[i] - ws_ - 1 : ws_; + + int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; + int is_max = 1 + params->idil[i] * (params->iS[i] - 1); + + valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0); + + is[i] = is_ / params->idil[i]; + + oS /= params->oS[i]; + wS /= params->wS[i]; + } + + if(valid) { + size_t in_offset = n * params->in_strides[0]; + + for(int i = 0; i < N; ++i) { + in_offset += is[i] * params->in_strides[i + 1]; + } + + out[gid.x] = in[in_offset + gid.x]; + } else { + out[gid.x] = T(0); + } + +} + +#define instantiate_naive_unfold_nd(name, itype, n) \ + template [[host_name("naive_unfold_nd_" #name "_" #n)]] \ + [[kernel]] void naive_unfold_Nd( \ + const device itype* in [[buffer(0)]], \ + device itype* out [[buffer(1)]], \ + const constant MLXConvParams* params [[buffer(2)]], \ + uint3 gid [[thread_position_in_grid]]); + +#define instantiate_naive_unfold_nd_dims(name, itype) \ + instantiate_naive_unfold_nd(name, itype, 1) \ + instantiate_naive_unfold_nd(name, itype, 2) \ + instantiate_naive_unfold_nd(name, itype, 3) + +instantiate_naive_unfold_nd_dims(float32, float); +instantiate_naive_unfold_nd_dims(float16, half); +instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Slow and naive conv2d kernels /////////////////////////////////////////////////////////////////////////////// template = 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0); @@ -116,59 +202,6 @@ instantiate_naive_conv_2d_blocks(float32, float); instantiate_naive_conv_2d_blocks(float16, half); instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); -/////////////////////////////////////////////////////////////////////////////// -/// Implicit gemm kernels -/////////////////////////////////////////////////////////////////////////////// - -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d( - const device T* in [[buffer(0)]], - const device T* wt [[buffer(1)]], - device T* out [[buffer(2)]], - const constant MLXConvParams<2>& params [[buffer(3)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - - using gemm_kernel = Conv2DImplicitGEMMKernel; - - threadgroup T tgp_memory[gemm_kernel::tgp_mem_size]; - - gemm_kernel::run( - in, wt, out, - params, tgp_memory, - tid, lid, simd_gid, simd_lid - ); - -} - -#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ - template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ - [[kernel]] void implicit_gemm_conv_2d( \ - const device itype* in [[buffer(0)]], \ - const device itype* wt [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant MLXConvParams<2>& params [[buffer(3)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -#define instantiate_implicit_2d_blocks(name, itype) \ - instantiate_implicit_conv_2d(name, itype, 32, 32, 32, 2, 2) \ - instantiate_implicit_conv_2d(name, itype, 32, 32, 16, 2, 2) \ - instantiate_implicit_conv_2d(name, itype, 64, 64, 16, 2, 2) - -instantiate_implicit_2d_blocks(float32, float); 
-instantiate_implicit_2d_blocks(float16, half); -instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); - /////////////////////////////////////////////////////////////////////////////// /// Winograd kernels /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/kernels/conv_params.h b/mlx/backend/metal/kernels/conv_params.h deleted file mode 100644 index b216bb976..000000000 --- a/mlx/backend/metal/kernels/conv_params.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -template -struct MLXConvParams { - const int N; // Batch size - const int C; // In channels - const int O; // Out channels - const int iS[NDIM]; // Input spatial dim - const int wS[NDIM]; // Weight spatial dim - const int oS[NDIM]; // Output spatial dim - const int str[NDIM]; // Kernel strides - const int pad[NDIM]; // Input padding - const int dil[NDIM]; // Kernel dilation - const size_t in_strides[NDIM + 2]; // In strides - const size_t wt_strides[NDIM + 2]; // Wt strides - const size_t out_strides[NDIM + 2]; // Out strides -}; diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal index 7a94be7da..071effeea 100644 --- a/mlx/backend/metal/kernels/scatter.metal +++ b/mlx/backend/metal/kernels/scatter.metal @@ -13,6 +13,58 @@ using namespace metal; // Scatter kernel ///////////////////////////////////////////////////////////////////// +template \ +METAL_FUNC void scatter_1d_index_impl( + const device T *updates [[buffer(1)]], + device mlx_atomic *out [[buffer(2)]], + const constant int* out_shape [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& upd_size [[buffer(5)]], + const constant bool& upd_col_contiguous [[buffer(6)]], + const thread array& idx_buffers, + uint2 gid [[thread_position_in_grid]]) { + + Op op; + + uint out_idx = 0; + for (int i = 0; i < NIDX; i++) { + auto idx_val = offset_neg_idx( + idx_buffers[i][gid.y], out_shape[i]); + out_idx += idx_val * out_strides[i]; + } + + if (!upd_col_contiguous) { + op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); + } else { + op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x); + } +} + +#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void scatter_1d_index( \ + const device T *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int* out_shape [[buffer(3)]], \ + const constant size_t* out_strides [[buffer(4)]], \ + const constant size_t& upd_size [[buffer(5)]], \ + const constant bool& upd_col_contiguous [[buffer(6)]], \ + IDX_ARG(IdxT) \ + uint2 gid [[thread_position_in_grid]]) { \ + \ + const array idx_buffers = {IDX_ARR()}; \ + \ + return scatter_1d_index_impl( \ + updates, \ + out, \ + out_shape, \ + out_strides, \ + upd_size, \ + upd_col_contiguous, \ + idx_buffers, \ + gid); \ + \ +} template METAL_FUNC void scatter_impl( @@ -46,10 +98,14 @@ METAL_FUNC void scatter_impl( out_idx += idx_val * out_strides[ax]; } - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + if (upd_size > 1) { + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + out_idx += out_offset; + } + auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx + out_offset); + op.atomic_update(out, updates[upd_idx], out_idx); } #define make_scatter_impl(IDX_ARG, IDX_ARR) \ 
@@ -90,9 +146,11 @@ template \ axes, \ idxs, \ gid); \ -} +} -#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) +#define make_scatter(n) \ +make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \ +make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n) make_scatter(0) make_scatter(1) @@ -129,8 +187,21 @@ template [[host_name("scatter" name "_" #nidx)]] \ IDX_ARG(idx_t) \ uint2 gid [[thread_position_in_grid]]); +#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ +template [[host_name("scatter_1d_index" name "_" #nidx)]] \ +[[kernel]] void scatter_1d_index( \ + const device src_t *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int* out_shape [[buffer(3)]], \ + const constant size_t* out_strides [[buffer(4)]], \ + const constant size_t& upd_size [[buffer(5)]], \ + const constant bool& upd_col_contiguous [[buffer(6)]], \ + IDX_ARG(idx_t) \ + uint2 gid [[thread_position_in_grid]]); + #define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ - instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) + instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \ + instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // Special case NINDEX=0 #define instantiate_scatter_nd0(name, type) \ diff --git a/mlx/backend/metal/kernels/steel/conv/conv.h b/mlx/backend/metal/kernels/steel/conv/conv.h new file mode 100644 index 000000000..e5065cea2 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/conv.h @@ -0,0 +1,11 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/loader.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +using namespace metal; +using namespace mlx::steel; \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal new file mode 100644 index 000000000..6f80622ad --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal @@ -0,0 +1,189 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" + +#include "mlx/backend/metal/kernels/steel/conv/conv.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/bf16.h" + +using namespace metal; + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + + using namespace mlx::steel; + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? 
BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + + using loader_a_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DInputBlockLoaderSmallChannels< + T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>, + + // Else go to general loader + typename metal::conditional_t< + // Check if filter size is small enough + SMALL_FILTER, + + // Go to small filter specialization + Conv2DInputBlockLoaderSmallFilter< + T, BM, BN, BK, tgp_size, tgp_padding_a>, + + // Else go to large filter generalization + Conv2DInputBlockLoaderLargeFilter< + T, BM, BN, BK, tgp_size, tgp_padding_a> + > + >; + + + // Weight loader + using loader_b_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DWeightBlockLoaderSmallChannels< + T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>, + + // Else go to general loader + Conv2DWeightBlockLoader + >; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + const int N = gemm_params->N; + + B += c_col * K; + C += c_row * N + c_col; + + const int2 offsets_a(0, c_row); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); + loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + short tgp_bm = min(BM, gemm_params->M - c_row); + short tgp_bn = min(BN, gemm_params->N - c_col); + mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm)); + +} + +#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \ + template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \ + [[kernel]] void implicit_gemm_conv_2d( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* C [[buffer(2)]], \ + const constant MLXConvParams<2>* params [[buffer(3)]], \ + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid 
[[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) + +#define instantiate_implicit_2d_blocks(name, itype) \ + instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ + instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ + instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ + instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ + instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ + instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) + +instantiate_implicit_2d_blocks(float32, float); +instantiate_implicit_2d_blocks(float16, half); +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal new file mode 100644 index 000000000..4f355af23 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal @@ -0,0 +1,209 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" + +#include "mlx/backend/metal/kernels/steel/conv/conv.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" +#include "mlx/backend/metal/kernels/bf16.h" + +using namespace metal; +using namespace mlx::steel; + +template > +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? 
BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + using loader_a_t = Conv2DInputBlockLoaderGeneral< + T, BM, BN, BK, tgp_size, tgp_padding_a>; + + // Weight loader + using loader_b_t = Conv2DWeightBlockLoaderGeneral< + T, BM, BN, BK, tgp_size, tgp_padding_b>; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int tid_z = tid.z; + + const int base_oh = tid_z / jump_params->f_out_jump_w; + const int base_ow = tid_z % jump_params->f_out_jump_w; + + const int base_wh = base_h[base_oh].weight_base; + const int base_ww = base_w[base_ow].weight_base; + + const int base_wh_size = base_h[base_oh].weight_size; + const int base_ww_size = base_w[base_ow].weight_size; + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + + B += c_col * K; + + const int4 offsets_a(0, c_row, base_oh, base_ow); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid); + loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + { + // Adjust for simdgroup and thread locatio + int offset_m = c_row + mma_op.sm + mma_op.tm; + int offset_n = c_col + mma_op.sn + mma_op.tn; + C += offset_n; + + if (offset_n >= gemm_params->N) + return; + + short diff = gemm_params->N - offset_n; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < mma_t::TM; i++) { + + int cm = offset_m + i * mma_t::TM_stride; + + int n = cm / jump_params->adj_out_hw; + int hw = cm % jump_params->adj_out_hw; + int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; + int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; + + if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) { + + int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2]; + + STEEL_PRAGMA_UNROLL + for (int j = 0; j < mma_t::TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements(); + int offset = offset_cm + (j * mma_t::TN_stride); + + // Apply epilogue and output C + if (j * mma_t::TN_stride < diff) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * mma_t::TN_stride + 1 < diff) { + C[offset + 1] = 
Epilogue::apply(accum[1]); + } + } + + } + } + } + +} + +#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ + template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ + [[kernel]] void implicit_gemm_conv_2d_general( \ + const device itype* A [[buffer(0)]], \ + const device itype* B [[buffer(1)]], \ + device itype* C [[buffer(2)]], \ + const constant MLXConvParams<2>* params [[buffer(3)]], \ + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \ + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \ + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \ + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) + +#define instantiate_implicit_2d_blocks(name, itype) \ + instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ + instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ + instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ + instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ + instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ + instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) + +instantiate_implicit_2d_blocks(float32, float); +instantiate_implicit_2d_blocks(float16, half); +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/loader.h b/mlx/backend/metal/kernels/steel/conv/loader.h new file mode 100644 index 000000000..f84a640f4 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/loader.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h" +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h" \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h new file mode 100644 index 000000000..dad496e81 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h @@ -0,0 +1,449 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderLargeFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
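Each template instantiation above is exported under a deterministic host name built from the dtype tag and the block configuration, so the host side can pick a variant by string at dispatch time. A small sketch of how such a name is assembled (the actual selection heuristics live in the conv host code and are not shown here):

    #include <sstream>
    #include <string>

    // Produces e.g. "implicit_gemm_conv_2d_general_float32_bm32_bn8_bk16_wm4_wn1",
    // matching the host_name strings generated by instantiate_implicit_conv_2d.
    std::string conv_kernel_name(
        const std::string& type_name, int bm, int bn, int bk, int wm, int wn) {
      std::ostringstream kname;
      kname << "implicit_gemm_conv_2d_general_" << type_name << "_bm" << bm
            << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
      return kname.str();
    }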
8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderLargeFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h * params->kdil[0]; + int iw = read_iw[i] + weight_w * params->kdil[1]; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) 
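The constructor above treats every implicit GEMM row as one output pixel: the row index is unflattened into (batch, out_y, out_x) and then mapped back to the top-left input coordinate of its receptive field via stride and padding. The standalone sketch below repeats that arithmetic with made-up shapes purely for illustration.

    #include <cstdio>

    int main() {
      // Hypothetical geometry, chosen only for this example.
      const int oH = 4, oW = 5;           // output spatial size
      const int strideH = 2, strideW = 2;
      const int padH = 1, padW = 1;

      const int row = 13;                 // one implicit GEMM row == one output pixel
      const int n  = row / (oH * oW);     // batch index
      const int hw = row % (oH * oW);
      const int oh = hw / oW;
      const int ow = hw % oW;

      // Top-left input coordinate of this output pixel's receptive field.
      const int ih = oh * strideH - padH;
      const int iw = ow * strideW - padW;

      std::printf("row %d -> n=%d oh=%d ow=%d ih=%d iw=%d\n", row, n, oh, ow, ih, iw);
      return 0;
    }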
>= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + using mask_t = short; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + mask_t mask_h[n_rows]; + mask_t mask_w[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + mask_h[i] = 0; + mask_w[i] = 0; + } + + for (short kh = 0; kh < params->wS[0]; kh++) { + short flip_h = params->flip ? params->wS[0] - kh - 1 : kh; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int n = read_n[i]; + int ih = read_ih[i] + flip_h * params->kdil[0]; + + bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0]; + + mask_h[i] |= (in_bounds << kh); + } + } + + for (short kw = 0; kw < params->wS[1]; kw++) { + short flip_w = params->flip ? 
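As a sanity check of the read-shape constants, here is the same arithmetic evaluated for one instantiated configuration (BM = 32, BK = 16, WM = WN = 2, i.e. a 128-thread threadgroup). This is a standalone compile-time sketch, not code from the patch.

    // 128 threads each read one 4-wide vector per pass, which exactly tiles
    // the 32 x 16 threadgroup tile (128 * 4 == 512 == 32 * 16).
    constexpr short BM = 32, BK = 16, WM = 2, WN = 2;
    constexpr short tgp_size = WM * WN * 32;                      // 128
    constexpr short vec_size = tgp_size / (BM * BK) >= 8 ? 8 : 4; // 4
    constexpr short TCOLS = BK / vec_size;                        // 4
    constexpr short TROWS = tgp_size / TCOLS;                     // 32
    constexpr short n_rows = BM / TROWS;                          // 1

    static_assert(tgp_size * vec_size == BM * BK, "threads cover the tile exactly");
    static_assert(n_rows == 1, "each thread owns a single row of the tile");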
params->wS[1] - kw - 1 : kw; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int iw = read_iw[i] + flip_w * params->kdil[1]; + + bool in_bounds = iw >= 0 && iw < params->iS[1]; + + mask_w[i] |= (in_bounds << kw); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + mask_t h_mask = mask_t(1) << weight_h; + mask_t w_mask = mask_t(1) << weight_w; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Read from input if in bounds + if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoader { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_hw(0), + read_n(offsets.y + bi), + do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((read_n + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL 
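The small-filter loader replaces per-tap bounds checks with two bitmasks: bit k of mask_h (resp. mask_w) records whether kernel tap k stays inside the input for this row, so the inner loop only tests a single AND. A tiny standalone sketch of the same bookkeeping, with illustrative sizes:

    #include <cstdio>

    int main() {
      using mask_t = short;

      // Hypothetical 1-D slice: 3-tap filter, 5 input rows, this thread's
      // receptive field starting at ih = -1 (i.e. in the top padding).
      const int wS = 3, iS = 5, ih = -1, kdil = 1;

      // Precompute one validity bit per kernel tap, as the constructor does.
      mask_t mask_h = 0;
      for (int kh = 0; kh < wS; ++kh) {
        bool in_bounds = (ih + kh * kdil) >= 0 && (ih + kh * kdil) < iS;
        mask_h |= (in_bounds << kh);
      }

      // At load time a single AND decides between loading and zero-padding.
      for (int weight_h = 0; weight_h < wS; ++weight_h) {
        bool ok = mask_h & (mask_t(1) << weight_h);
        std::printf("tap %d: %s\n", weight_h, ok ? "load" : "zero-pad");
      }
      return 0;
    }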
+ for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_hw < (params->wS[1] * params->wS[0])) { + src += params->wt_strides[2]; + return; + } + + weight_hw = 0; + + src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h new file mode 100644 index 000000000..56027916e --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -0,0 +1,319 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct ChannelHelper { + STEEL_CONST short n_channels = n_channels_; + STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8; + STEEL_CONST short excess = vec_size - n_channels_; +}; + +template <> +struct ChannelHelper<1> { + STEEL_CONST short n_channels = 1; + STEEL_CONST short vec_size = 1; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<2> { + STEEL_CONST short n_channels = 2; + STEEL_CONST short vec_size = 2; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<3> { + STEEL_CONST short n_channels = 3; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 1; +}; + +template <> +struct ChannelHelper<4> { + STEEL_CONST short n_channels = 4; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 0; +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_hw; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_hw(thread_idx % TCOLS) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; 
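ChannelHelper picks the smallest aligned vector width for a small channel count and records how many lanes of that vector are padding (the <1> and <2> specializations keep the width equal to the channel count). A quick compile-time check of the primary template's rule, as a standalone sketch that mirrors, rather than reuses, the struct above:

    // With 3 channels, reads are 4-wide with one zero-filled lane; with 5
    // channels the primary template switches to 8-wide vectors.
    template <int n_channels_>
    struct ChannelHelperSketch {
      static constexpr int vec_size = n_channels_ <= 4 ? 4 : 8;
      static constexpr int excess = vec_size - n_channels_;
    };

    static_assert(ChannelHelperSketch<3>::vec_size == 4, "");
    static_assert(ChannelHelperSketch<3>::excess == 1, "");
    static_assert(ChannelHelperSketch<5>::vec_size == 8, "");
    static_assert(ChannelHelperSketch<5>::excess == 3, "");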
++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + int wh = (weight_hw / params->wS[1]); + int ww = (weight_hw % params->wS[1]); + + int flip_h = params->flip ? params->wS[0] - wh - 1 : wh; + int flip_w = params->flip ? params->wS[1] - ww - 1 : ww; + + int weight_h = flip_h * params->kdil[0]; + int weight_w = flip_w * params->kdil[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h; + int iw = read_iw[i] + weight_w; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + const device T* curr_src = src[i] + weight_h * params->in_strides[1] + + weight_w * params->in_strides[2]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; ++j) { + dst[is * dst_ld + j] = curr_src[j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld), + params(params_), + 
weight_hw(thread_idx % TCOLS), + read_n(offsets.y + bi), + do_read(read_n + BN <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (bi >= BROWS || bj >= BCOLS) + return; + + if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + + return; + } + + const device T* curr_src = src + weight_hw * params->wt_strides[2]; + + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } else { + for (short i = 0; i < BROWS; i += TROWS) { + if (((read_n + i) < params->O)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h new file mode 100644 index 000000000..3e396c2af --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h @@ -0,0 +1,288 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int4 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / jump_params->adj_out_hw; + int hw = offset_nhw % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += BK; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const int start_row; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_), + start_row(offsets.y + bi) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld 
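The division by params->idil above recovers the real input index from a coordinate in the dilated (zero-inserted) input, which is how transposed convolutions are expressed; the base_h/base_w offsets and jump parameters appear to be what restricts the visited taps to coordinates that actually land on real samples, so the kernel itself can skip a divisibility test. Purely as an illustration of the mapping (hypothetical input dilation of 2, 4 real samples; not code from the patch):

    #include <cstdio>

    int main() {
      const int idil = 2; // input dilation: idil - 1 implicit zeros between samples
      const int iS = 4;   // number of real input samples

      for (int ih_dil = -1; ih_dil < 8; ++ih_dil) {
        bool real = ih_dil >= 0 && ih_dil % idil == 0 && (ih_dil / idil) < iS;
        if (real)
          std::printf("dilated coordinate %d -> real sample %d\n", ih_dil, ih_dil / idil);
        else
          std::printf("dilated coordinate %d -> implicit zero / out of range\n", ih_dil);
      }
      return 0;
    }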
+ j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + src += BK; + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/conv/params.h b/mlx/backend/metal/kernels/steel/conv/params.h new file mode 100644 index 000000000..f75851dc8 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/params.h @@ -0,0 +1,62 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +template +struct MLXConvParams { + const int N; // Batch size + const int C; // In channels + const int O; // Out channels + const int iS[NDIM]; // Input spatial dim + const int wS[NDIM]; // Weight spatial dim + const int oS[NDIM]; // Output spatial dim + const int str[NDIM]; // Kernel strides + const int pad[NDIM]; // Input padding + const int kdil[NDIM]; // Kernel dilation + const int idil[NDIM]; // Input dilation + const size_t in_strides[NDIM + 2]; // In strides + const size_t wt_strides[NDIM + 2]; // Wt strides + const size_t out_strides[NDIM + 2]; // Out strides + const int groups; // Input channel groups + const bool flip; +}; + +namespace mlx { +namespace steel { + +struct ImplicitGemmConv2DParams { + const int M; + const int N; + const int K; + + const int gemm_k_iterations; + + const int inp_jump_w; + const int inp_jump_h; + const int inp_jump_c; + + const int tiles_n; + const int tiles_m; + const int swizzle_log; +}; + +struct Conv2DGeneralJumpParams { + const int f_wgt_jump_h; + const int f_wgt_jump_w; + + const int f_out_jump_h; + const int f_out_jump_w; + + const int adj_out_h; + const int adj_out_w; + const int adj_out_hw; + const int adj_implicit_m; +}; + +struct Conv2DGeneralBaseInfo { + int weight_base; + int weight_size; +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm.h b/mlx/backend/metal/kernels/steel/gemm/gemm.h index be70bcacb..2e2b0f838 100644 --- a/mlx/backend/metal/kernels/steel/gemm/gemm.h +++ b/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -4,6 +4,7 @@ #include "mlx/backend/metal/kernels/steel/gemm/loader.h" #include "mlx/backend/metal/kernels/steel/gemm/mma.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" #include "mlx/backend/metal/kernels/steel/utils.h" diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index 6f58bfcaf..d98625ae6 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -2,9 +2,15 @@ #pragma once +#include +#include +#include + #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" #include "mlx/backend/metal/kernels/steel/utils.h" +using namespace metal; + /////////////////////////////////////////////////////////////////////////////// // MMA helper /////////////////////////////////////////////////////////////////////////////// @@ -167,6 +173,9 @@ struct BlockMMA { C += (sm + tm) * ldc + (tn + sn); dst_tile_dims -= short2(tn + sn, sm + tm); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { @@ 
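ImplicitGemmConv2DParams carries the GEMM view of the convolution. For the channels-last layout used here the usual mapping is one GEMM row per output pixel, one column per output channel, and a reduction over kernel taps times input channels; the hedged sketch below spells that out (the real parameter setup lives in the conv host code and also handles tiling, jumps, and swizzling, all omitted here):

    // Sketch only: derives the implicit GEMM sizes from a conv description.
    // Field names loosely mirror MLXConvParams.
    struct ConvShapeSketch {
      int N, C, O;  // batch, in channels, out channels
      int oS[2];    // output spatial dims
      int wS[2];    // weight spatial dims
    };

    struct GemmShapeSketch {
      int M, N, K;
    };

    inline GemmShapeSketch implicit_gemm_shape(const ConvShapeSketch& p) {
      return GemmShapeSketch{
          /* M = */ p.N * p.oS[0] * p.oS[1], // one GEMM row per output pixel
          /* N = */ p.O,                     // one GEMM column per output channel
          /* K = */ p.wS[0] * p.wS[1] * p.C  // reduction over taps x channels
      };
    }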
-236,6 +245,9 @@ struct BlockMMA { D += (sm + tm) * ldd + tn + sn; dst_tile_dims -= short2(tn + sn, sm + tm); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { diff --git a/mlx/backend/metal/kernels/steel/host.h b/mlx/backend/metal/kernels/steel/host.h deleted file mode 100644 index 6fb4e54c9..000000000 --- a/mlx/backend/metal/kernels/steel/host.h +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "mlx/backend/metal/kernels/steel/gemm/params.h" \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/utils.h b/mlx/backend/metal/kernels/steel/utils.h index a4b6aa261..c5550cef3 100644 --- a/mlx/backend/metal/kernels/steel/utils.h +++ b/mlx/backend/metal/kernels/steel/utils.h @@ -3,7 +3,6 @@ #pragma once #include -#include "mlx/backend/metal/kernels/steel/host.h" #define STEEL_CONST static constant constexpr const #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") \ No newline at end of file diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h new file mode 100644 index 000000000..e0235d9dd --- /dev/null +++ b/mlx/backend/metal/kernels/ternary.h @@ -0,0 +1,10 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +struct Select { + template + T operator()(bool condition, T x, T y) { + return condition ? x : y; + } +}; diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal new file mode 100644 index 000000000..c351bed17 --- /dev/null +++ b/mlx/backend/metal/kernels/ternary.metal @@ -0,0 +1,201 @@ +// Copyright © 2023 Apple Inc. + +#include +#include + +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/ternary.h" + +template +[[kernel]] void ternary_op_v( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + uint index [[thread_position_in_grid]]) { + d[index] = Op()(a[index], b[index], c[index]); +} + +template +[[kernel]] void ternary_op_g_nd1( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t& a_strides, + constant const size_t& b_strides, + constant const size_t& c_strides, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); + d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_op_g_nd2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + constant const size_t c_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_op_g_nd3( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + constant const size_t c_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, 
b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_op_g_nd( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int shape[DIM], + constant const size_t a_strides[DIM], + constant const size_t b_strides[DIM], + constant const size_t c_strides[DIM], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); +} + +template +[[kernel]] void ternary_op_g( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim); + size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); + d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); +} + +#define instantiate_ternary_v(name, type, op) \ + template [[host_name(name)]] \ + [[kernel]] void ternary_op_v( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + uint index [[thread_position_in_grid]]); \ + +#define instantiate_ternary_g(name, type, op) \ + template [[host_name(name)]] \ + [[kernel]] void ternary_op_g( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const int* shape, \ + constant const size_t* a_strides, \ + constant const size_t* b_strides, \ + constant const size_t* c_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + +#define instantiate_ternary_g_dim(name, type, op, dims) \ + template [[host_name(name "_" #dims)]] \ + [[kernel]] void ternary_op_g_nd( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const int shape[dims], \ + constant const size_t a_strides[dims], \ + constant const size_t b_strides[dims], \ + constant const size_t c_strides[dims], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + +#define instantiate_ternary_g_nd(name, type, op) \ + template [[host_name(name "_1")]] \ + [[kernel]] void ternary_op_g_nd1( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const size_t& a_strides, \ + constant const size_t& b_strides, \ + constant const size_t& c_strides, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] \ + [[kernel]] void ternary_op_g_nd2( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + constant const size_t a_strides[2], \ + constant const size_t b_strides[2], \ + constant const size_t c_strides[2], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] \ + [[kernel]] void ternary_op_g_nd3( \ + device const bool* a, \ + device const type* b, \ + device const type* c, \ + device type* d, \ + 
constant const size_t a_strides[3], \ + constant const size_t b_strides[3], \ + constant const size_t c_strides[3], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + instantiate_ternary_g_dim(name, type, op, 4) \ + instantiate_ternary_g_dim(name, type, op, 5) \ + +#define instantiate_ternary_all(name, tname, type, op) \ + instantiate_ternary_v("v" #name #tname, type, op) \ + instantiate_ternary_g("g" #name #tname, type, op) \ + instantiate_ternary_g_nd("g" #name #tname, type, op) \ + +#define instantiate_ternary_types(name, op) \ + instantiate_ternary_all(name, bool_, bool, op) \ + instantiate_ternary_all(name, uint8, uint8_t, op) \ + instantiate_ternary_all(name, uint16, uint16_t, op) \ + instantiate_ternary_all(name, uint32, uint32_t, op) \ + instantiate_ternary_all(name, uint64, uint64_t, op) \ + instantiate_ternary_all(name, int8, int8_t, op) \ + instantiate_ternary_all(name, int16, int16_t, op) \ + instantiate_ternary_all(name, int32, int32_t, op) \ + instantiate_ternary_all(name, int64, int64_t, op) \ + instantiate_ternary_all(name, float16, half, op) \ + instantiate_ternary_all(name, float32, float, op) \ + instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \ + instantiate_ternary_all(name, complex64, complex64_t, op) \ + +instantiate_ternary_types(select, Select) diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 6d086b775..e0d80ab10 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -9,6 +9,10 @@ #include "mlx/backend/metal/kernels/erf.h" #include "mlx/backend/metal/kernels/utils.h" +namespace { +constant float inf = metal::numeric_limits::infinity(); +} + struct Abs { template T operator()(T x) { diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 8ef1127b6..9c3d20b30 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -91,6 +91,30 @@ inline size_t elem_to_loc( return loc; } +template +inline uint3 elem_to_loc_3_nd( + uint3 elem, + constant const int shape[NDIM], + constant const size_t a_strides[NDIM], + constant const size_t b_strides[NDIM], + constant const size_t c_strides[NDIM]) { + uint3 loc = { + static_cast( + elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), + static_cast( + elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]), + static_cast( + elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])}; + for (int d = NDIM - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + loc.z += l * c_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + template inline uint2 elem_to_loc_2_nd( uint3 elem, @@ -150,6 +174,30 @@ inline size_t elem_to_loc( return loc; } +inline uint3 elem_to_loc_3_nd( + uint3 elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + uint3 loc = { + static_cast( + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + static_cast( + elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]), + static_cast( + elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + loc.z += l * c_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + inline uint2 elem_to_loc_2_nd( uint3 elem, constant const int* shape, diff --git 
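elem_to_loc_3_nd resolves a single grid position into three input offsets at once, so a broadcast operand is handled simply by giving it zero strides along the broadcast axes. A host-side re-implementation for a quick check (standalone C++, not the Metal code):

    #include <array>
    #include <cstdio>
    #include <vector>

    // Host mirror of elem_to_loc_3_nd: x/y index the two innermost dims and z
    // packs all remaining dims.
    std::array<size_t, 3> elem_to_loc_3(
        unsigned x, unsigned y, unsigned z,
        const std::vector<int>& shape,
        const std::vector<size_t>& a,
        const std::vector<size_t>& b,
        const std::vector<size_t>& c) {
      int ndim = static_cast<int>(shape.size());
      std::array<size_t, 3> loc = {
          x * a[ndim - 1] + y * a[ndim - 2],
          x * b[ndim - 1] + y * b[ndim - 2],
          x * c[ndim - 1] + y * c[ndim - 2]};
      for (int d = ndim - 3; d >= 0; --d) {
        unsigned l = z % shape[d];
        loc[0] += l * a[d];
        loc[1] += l * b[d];
        loc[2] += l * c[d];
        z /= shape[d];
      }
      return loc;
    }

    int main() {
      // Shape (2, 3, 4); b is broadcast over the leading axis (stride 0 there).
      std::vector<int> shape = {2, 3, 4};
      std::vector<size_t> a = {12, 4, 1}, b = {0, 4, 1}, c = {12, 4, 1};
      auto loc = elem_to_loc_3(1, 2, 1, shape, a, b, c);
      std::printf("a=%zu b=%zu c=%zu\n", loc[0], loc[1], loc[2]); // a=21 b=9 c=21
      return 0;
    }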
a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index c0b3cb19b..76b192d35 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -8,7 +8,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/steel/host.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/mps/gemm.h" #include "mlx/backend/metal/utils.h" diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 96cf87818..6035d2a7f 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -10,6 +10,10 @@ namespace mlx::core::metal { +bool is_available() { + return true; +} + int max_ops_per_buffer() { auto get_val = []() { if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) { diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index d249daac0..360481f81 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -11,16 +11,52 @@ namespace mlx::core::metal { -constexpr bool is_available() { -#ifdef _METAL_ - return true; -#else - return false; -#endif -} +bool is_available(); -bool cache_enabled(void); -void set_cache_enabled(bool enabled); +/* Get the actively used memory in bytes. + * + * Note, this will not always match memory use reported by the system because + * it does not include cached memory buffers. + * */ +size_t get_active_memory(); + +/* Get the peak amount of used memory in bytes. + * + * The maximum memory used is recorded from the beginning of the program + * execution. + * */ +size_t get_peak_memory(); + +/* Get the cache size in bytes. + * + * The cache includes memory not currently used that has not been returned + * to the system allocator. + * */ +size_t get_cache_memory(); + +/* Set the memory limit. + * Calls to malloc will wait on scheduled tasks if the limit is exceeded. If + * there are no more scheduled tasks an error will be raised if relaxed + * is false or memory will be allocated (including the potential for + * swap) if relaxed is true. + * + * The memory limit defaults to 1.5 times the maximum recommended working set + * size reported by the device. + * + * Returns the previous memory limit. + * */ +size_t set_memory_limit(size_t limit, bool relaxed = true); + +/* Set the free cache limit. + * If using more than the given limit, free memory will be reclaimed + * from the cache on the next allocation. To disable the cache, + * set the limit to 0. + * + * The cache limit defaults to the memory limit. + * + * Returns the previous cache limit. + * */ +size_t set_cache_limit(size_t limit); void new_stream(Stream stream); std::shared_ptr new_scoped_memory_pool(); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 38ec5993c..0f2716a1b 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -1,11 +1,11 @@ // Copyright © 2023-2024 Apple Inc. 
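The metal.h changes replace the old cache_enabled flag with explicit memory accounting and limits. A brief usage sketch of the declared C++ API; the include path and the 4 GiB figure are assumptions made only for this example:

    #include <cstdio>

    #include "mlx/backend/metal/metal.h" // assumed include path for these declarations

    int main() {
      namespace metal = mlx::core::metal;

      if (!metal::is_available()) {
        std::printf("Metal backend not built; the no_metal stubs return 0.\n");
        return 0;
      }

      // Disable the buffer cache and remember the previous limit.
      size_t old_cache_limit = metal::set_cache_limit(0);

      // Cap allocations at 4 GiB; with relaxed = false, exceeding the limit with
      // no scheduled work raises an error instead of allocating anyway.
      size_t old_mem_limit =
          metal::set_memory_limit(size_t(4) << 30, /* relaxed = */ false);

      std::printf("active=%zu peak=%zu cache=%zu (previous limits: %zu, %zu)\n",
                  metal::get_active_memory(),
                  metal::get_peak_memory(),
                  metal::get_cache_memory(),
                  old_mem_limit,
                  old_cache_limit);
      return 0;
    }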
- #include #include #include #include #include "mlx/backend/common/binary.h" +#include "mlx/backend/common/ternary.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" @@ -43,24 +43,25 @@ void binary_op( std::ostringstream kname; switch (bopt) { - case ScalarScalar: + case BinaryOpType::ScalarScalar: kname << "ss"; break; - case ScalarVector: + case BinaryOpType::ScalarVector: kname << "sv"; break; - case VectorScalar: + case BinaryOpType::VectorScalar: kname << "vs"; break; - case VectorVector: + case BinaryOpType::VectorVector: kname << "vv"; break; - case General: + case BinaryOpType::General: kname << "g"; break; } kname << op << type_to_name(a); - if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { + if (bopt == BinaryOpType::General && + shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { kname << "_" << shape.size(); } @@ -80,7 +81,7 @@ void binary_op( set_array_buffer(compute_encoder, outputs[0], 2); set_array_buffer(compute_encoder, outputs[1], 3); - if (bopt == General) { + if (bopt == BinaryOpType::General) { auto ndim = shape.size(); if (ndim > 3) { compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); @@ -141,24 +142,25 @@ void binary_op( std::ostringstream kname; switch (bopt) { - case ScalarScalar: + case BinaryOpType::ScalarScalar: kname << "ss"; break; - case ScalarVector: + case BinaryOpType::ScalarVector: kname << "sv"; break; - case VectorScalar: + case BinaryOpType::VectorScalar: kname << "vs"; break; - case VectorVector: + case BinaryOpType::VectorVector: kname << "vv"; break; - case General: + case BinaryOpType::General: kname << "g"; break; } kname << op << type_to_name(a); - if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { + if (bopt == BinaryOpType::General && + shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { kname << "_" << shape.size(); } @@ -173,7 +175,7 @@ void binary_op( set_array_buffer(compute_encoder, donate_b ? out : b, 1); set_array_buffer(compute_encoder, out, 2); - if (bopt == General) { + if (bopt == BinaryOpType::General) { auto ndim = shape.size(); if (ndim > 3) { compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3); @@ -202,7 +204,94 @@ void binary_op( compute_encoder->dispatchThreads(grid_dims, group_dims); } else { // Launch a 1D grid of threads - size_t nthreads = bopt == General ? out.size() : out.data_size(); + size_t nthreads = + bopt == BinaryOpType::General ? 
out.size() : out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } +} + +void ternary_op( + const std::vector& inputs, + array& out, + const std::string op) { + assert(inputs.size() == 3); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& c = inputs[2]; + TernaryOpType topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */); + + if (out.size() == 0) { + return; + } + + // Try to collapse contiguous dims + auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); + auto& strides_a = strides[0]; + auto& strides_b = strides[1]; + auto& strides_c = strides[2]; + auto& strides_out = strides[3]; + + std::ostringstream kname; + if (topt == TernaryOpType::General) { + kname << "g"; + kname << op << type_to_name(b); + if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { + kname << "_" << shape.size(); + } + } else { + kname << "v"; + kname << op << type_to_name(b); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, c, 2); + set_array_buffer(compute_encoder, out, 3); + + if (topt == TernaryOpType::General) { + auto ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); + compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7); + + if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + compute_encoder->setBytes(&ndim, sizeof(int), 8); + } + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); + compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6); + } + + // Launch up to 3D grid of threads + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? 
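Kernel selection in ternary_op is by name: a placement prefix, the op string, the dtype name, and (for the general path) the collapsed rank are concatenated, and the result has to match one of the host_name strings instantiated in ternary.metal, e.g. "vselectfloat32" or "gselectfloat32_2". A sketch of that assembly; treating MAX_BINARY_SPECIALIZED_DIMS as 5 and the dtype string as "float32" are assumptions of this example:

    #include <sstream>
    #include <string>

    // Reproduces the name assembly in ternary_op for illustration.
    std::string ternary_kernel_name(
        bool general, int ndim, const std::string& op, const std::string& dtype) {
      std::ostringstream kname;
      kname << (general ? "g" : "v") << op << dtype;
      if (general && ndim <= 5) // MAX_BINARY_SPECIALIZED_DIMS assumed to be 5
        kname << "_" << ndim;
      return kname.str();
    }

    // ternary_kernel_name(false, 1, "select", "float32") -> "vselectfloat32"
    // ternary_kernel_name(true,  2, "select", "float32") -> "gselectfloat32_2"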
shape[ndim - 2] : 1; + size_t rest = out.size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + MTL::Size group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + // Launch a 1D grid of threads + size_t nthreads = out.data_size(); MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (thread_group_size > nthreads) { @@ -430,8 +519,6 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&ndim, sizeof(size_t), 5); compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6); compute_encoder->setBytes(&axis_size, sizeof(size_t), 7); - compute_encoder->setThreadgroupMemoryLength( - simd_size * (sizeof(uint32_t) + in.itemsize()), 0); compute_encoder->dispatchThreads(grid_dims, group_dims); } } @@ -621,6 +708,10 @@ void Multiply::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "mul"); } +void Select::eval_gpu(const std::vector& inputs, array& out) { + ternary_op(inputs, out, "select"); +} + void Negative::eval_gpu(const std::vector& inputs, array& out) { unary_op(inputs, out, "neg"); } diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 257fcbb5d..240e00c41 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include @@ -6,6 +6,10 @@ namespace mlx::core::metal { +bool is_available() { + return false; +} + void new_stream(Stream) {} std::shared_ptr new_scoped_memory_pool() { return nullptr; @@ -19,10 +23,21 @@ std::function make_task( "[metal::make_task] Cannot make GPU task without metal backend"); } -// No cache for CPU only -bool cache_enabled(void) { - return false; +// No-ops when Metal is not available. 
+size_t get_active_memory() { + return 0; +} +size_t get_peak_memory() { + return 0; +} +size_t get_cache_memory() { + return 0; +} +size_t set_memory_limit(size_t, bool) { + return 0; +} +size_t set_cache_limit(size_t) { + return 0; } -void set_cache_enabled(bool) {} } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 8e66f56b3..4234eeb1c 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -80,6 +80,7 @@ NO_GPU(Reshape) NO_GPU(Round) NO_GPU(Scan) NO_GPU(Scatter) +NO_GPU(Select) NO_GPU(Sigmoid) NO_GPU(Sign) NO_GPU(Sin) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 700c07ced..e54778fe1 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -47,6 +47,10 @@ bool is_binary(const Primitive& p) { typeid(p) == typeid(Subtract)); } +bool is_ternary(const Primitive& p) { + return typeid(p) == typeid(Select); +} + bool is_broadcast(const Primitive& p) { return typeid(p) == typeid(Broadcast); } @@ -60,14 +64,16 @@ bool is_reduction(const Primitive& p) { } bool is_fusable(const Primitive& p) { - return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p); + return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) || + is_noop(p); } bool allows_shapeless(const Primitive& p) { return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) || - typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition); + typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) || + typeid(p) == typeid(Select); } Compiled::Compiled( diff --git a/mlx/dtype.h b/mlx/dtype.h index d52830485..fec1725f3 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -46,10 +46,6 @@ struct Dtype { }; }; -inline bool is_available(const Dtype& dtype) { - return true; -} - static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)}; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 97d4a3a2d..04b2a8e5b 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. 
#include +#include #include #include #include @@ -73,10 +74,24 @@ array arange( if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) { throw std::invalid_argument("[arange] Cannot compute length."); } - double real_size = std::ceil((stop - start) / step); - if (std::isnan(real_size)) { + + if (std::isinf(start) || std::isinf(stop)) { throw std::invalid_argument("[arange] Cannot compute length."); } + + // Check if start and stop specify a valid range because if not, we have to + // return an empty array + if (std::isinf(step) && + (step > 0 && start < stop || step < 0 && start > stop)) { + return array({start}, dtype); + } + + double real_size = std::ceil((stop - start) / step); + + if (real_size > INT_MAX) { + throw std::invalid_argument("[arange] Maximum size exceeded."); + } + int size = std::max(static_cast(real_size), 0); return array( {size}, @@ -1149,13 +1164,20 @@ array isneginf(const array& a, StreamOrDevice s /* = {} */) { } array where( - const array& condition, - const array& x, - const array& y, + const array& a, + const array& b, + const array& c, StreamOrDevice s /* = {} */) { - // TODO, fix this to handle the NaN case when x has infs - auto mask = astype(condition, bool_, s); - return add(multiply(x, mask, s), multiply(y, logical_not(mask, s), s), s); + auto condition = astype(a, bool_, s); + Dtype out_dtype = promote_types(b.dtype(), c.dtype()); + auto inputs = broadcast_arrays( + {condition, astype(b, out_dtype, s), astype(c, out_dtype, s)}, s); + + return array( + inputs[0].shape(), + out_dtype, + std::make_unique
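The previous where built the result as x * mask + y * (1 - mask), which yields NaN whenever the unselected branch is non-finite (0 * inf is NaN); routing it through a real Select primitive fixes that and also gives the GPU and compile/fusion paths a first-class ternary op. A small C++ sketch of the difference; the umbrella header is an assumption of the example:

    #include <limits>

    #include "mlx/mlx.h" // assumed umbrella header for the C++ API

    using namespace mlx::core;

    int main() {
      float inf = std::numeric_limits<float>::infinity();

      array condition = array({true, false});
      array x = array({1.0f, inf});
      array y = array({-1.0f, 2.0f});

      // Select semantics: take x where condition holds, y elsewhere.
      array out = where(condition, x, y);

      // The old mask-multiply formulation turns the unselected inf into
      // inf * 0 == nan in the second element.
      array mask = astype(condition, bool_);
      array old_style = add(multiply(x, mask), multiply(y, logical_not(mask)));

      eval(out, old_style); // out -> [1, 2], old_style -> [1, nan]
      return 0;
    }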