diff --git a/.circleci/config.yml b/.circleci/config.yml index a2455aa19..9965c98e4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -114,7 +114,13 @@ jobs: name: Run CPP tests command: | DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests - DEVICE=cpu ./build/tests/tests + - run: + name: Build small binary + command: | + source env/bin/activate + cd build/ + cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON + make -j build_release: parameters: diff --git a/CMakeLists.txt b/CMakeLists.txt index 29b0a8784..be6ebe1c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON) +option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) @@ -109,7 +110,7 @@ elseif (MLX_BUILD_METAL) $ ) target_link_libraries( - mlx + mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) @@ -122,7 +123,7 @@ if (MLX_BUILD_CPU) if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY) message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") set(MLX_BUILD_ACCELERATE ON) - target_link_libraries(mlx ${ACCELERATE_LIBRARY}) + target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) add_compile_definitions(ACCELERATE_NEW_LAPACK) else() message(STATUS "Accelerate or arm neon not found, using default backend.") @@ -145,7 +146,7 @@ if (MLX_BUILD_CPU) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) - target_link_libraries(mlx ${LAPACK_LIBRARIES}) + target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES}) # 
List blas after lapack otherwise we may accidentally incldue an old version # of lapack.h from the include dirs of blas. find_package(BLAS REQUIRED) @@ -160,7 +161,7 @@ if (MLX_BUILD_CPU) message(STATUS "Blas lib " ${BLAS_LIBRARIES}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) - target_link_libraries(mlx ${BLAS_LIBRARIES}) + target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES}) endif() else() set(MLX_BUILD_ACCELERATE OFF) @@ -175,6 +176,14 @@ target_include_directories( $ ) +FetchContent_Declare(fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt.git + GIT_TAG 10.2.1 + EXCLUDE_FROM_ALL +) +FetchContent_MakeAvailable(fmt) +target_link_libraries(mlx PRIVATE fmt::fmt-header-only) + if (MLX_BUILD_PYTHON_BINDINGS) message(STATUS "Building Python bindings.") find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) diff --git a/docs/src/install.rst b/docs/src/install.rst index 693385e2c..7b99c3145 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -163,6 +163,8 @@ should point to the path to the built metal library. - ON * - MLX_BUILD_GGUF - ON + * - MLX_METAL_JIT + - OFF .. note:: @@ -196,9 +198,18 @@ GGUF, you can do: cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \ -DBUILD_SHARED_LIBS=ON \ - -DMLX_BUILD_CPU=ON \ + -DMLX_BUILD_CPU=OFF \ -DMLX_BUILD_SAFETENSORS=OFF \ -DMLX_BUILD_GGUF=OFF + -DMLX_METAL_JIT=ON + +The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which +contains pre-built GPU kernels. This substantially reduces the size of the +Metal library by run-time compiling kernels the first time they are used in MLX +on a given machine. Note run-time compilation incurs a cold-start cost which can +be anywhere from a few hundred milliseconds to a few seconds depending on the +application. Once a kernel is compiled, it will be cached by the system. The +Metal kernel cache persists across reboots. 
Troubleshooting ^^^^^^^^^^^^^^^ diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 673d9cd14..19f32bd6d 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -1,6 +1,8 @@ // Copyright © 2023 Apple Inc. #pragma once +#include + #include "mlx/allocator.h" #include "mlx/array.h" #include "mlx/backend/common/utils.h" diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp index 2af5d8ddf..5fd9c8065 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/common/cholesky.cpp @@ -98,12 +98,4 @@ void Cholesky::eval(const std::vector& inputs, array& output) { cholesky_impl(inputs[0], output, upper_); } -std::pair, std::vector> Cholesky::vmap( - const std::vector& inputs, - const std::vector& axes) { - auto ax = axes[0] >= 0 ? 0 : -1; - auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - return {{linalg::cholesky(a, upper_, stream())}, {ax}}; -} - } // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index ccc7fb9c3..46dceb788 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -1,33 +1,80 @@ -add_custom_command( - OUTPUT compiled_preamble.cpp +function(make_jit_source SRC_NAME) + # This function takes a metal header file, + # runs the C preprocessor on it, and makes + # the processed contents available as a string in a C++ function + # mlx::core::metal::${SRC_NAME}() + # + # To use the function, declare it in jit/includes.h and + # include jit/includes.h. + # + # Additional arguments to this function are treated as dependencies + # in the CMake build system. 
+ add_custom_command( + OUTPUT jit/${SRC_NAME}.cpp COMMAND /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh - ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} + ${SRC_NAME} "-D${MLX_METAL_VERSION}" DEPENDS make_compiled_preamble.sh - kernels/compiled_preamble.h - kernels/unary.h - kernels/binary.h - kernels/bf16.h - kernels/erf.h - kernels/expm1f.h - kernels/utils.h - kernels/bf16_math.h -) + kernels/${SRC_NAME}.h + ${ARGN} + ) + add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) + add_dependencies(mlx ${SRC_NAME}) + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp + ) +endfunction(make_jit_source) -add_custom_target( - compiled_preamble - DEPENDS compiled_preamble.cpp +make_jit_source( + utils + kernels/bf16.h + kernels/complex.h ) +make_jit_source( + unary_ops + kernels/erf.h + kernels/expm1f.h +) +make_jit_source(binary_ops) +make_jit_source(ternary_ops) +make_jit_source( + reduction + kernels/atomic.h + kernels/reduction/ops.h +) +make_jit_source(scatter) +make_jit_source(gather) -add_dependencies(mlx compiled_preamble) +if (MLX_METAL_JIT) + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp + ) + make_jit_source(copy) + make_jit_source(unary) + make_jit_source(binary) + make_jit_source(binary_two) + make_jit_source(ternary) +else() + target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp + ) +endif() target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp @@ -46,7 +93,8 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp - ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp ) if (NOT MLX_METAL_PATH) diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp new file mode 100644 index 000000000..4b1e0f01d --- /dev/null +++ b/mlx/backend/metal/binary.cpp @@ -0,0 +1,322 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5; + +void binary_op( + const std::vector& inputs, + std::vector& outputs, + const std::string op) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt, true); + set_binary_op_output_data(a, b, outputs[1], bopt, true); + + auto& out = outputs[0]; + if (out.size() == 0) { + return; + } + + // Try to collapse contiguous dims + auto [shape, strides] = collapse_contiguous_dims(a, b, out); + auto& strides_a = strides[0]; + auto& strides_b = strides[1]; + auto& strides_out = strides[2]; + + std::string kernel_name; + { + std::ostringstream kname; + switch (bopt) { + case BinaryOpType::ScalarScalar: + kname << "ss"; + break; + case BinaryOpType::ScalarVector: + kname << "sv"; + break; + case BinaryOpType::VectorScalar: + kname << "vs"; + break; + case BinaryOpType::VectorVector: + kname << "vv"; + break; + case BinaryOpType::General: + kname << "g"; + if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { + kname << shape.size(); + } else { + kname << "n"; + } + break; + } + kname << op << type_to_name(a); + kernel_name = kname.str(); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + + auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + // - If a is donated it goes 
to the first output + // - If b is donated it goes to the first output if a was not donated + // otherwise it goes to the second output + bool donate_a = a.data_shared_ptr() == nullptr; + bool donate_b = b.data_shared_ptr() == nullptr; + compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0); + compute_encoder.set_input_array( + donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1); + compute_encoder.set_output_array(outputs[0], 2); + compute_encoder.set_output_array(outputs[1], 3); + + if (bopt == BinaryOpType::General) { + auto ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); + } + + if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + compute_encoder->setBytes(&ndim, sizeof(int), 7); + } + + // Launch up to 3D grid of threads + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1; + size_t rest = out.size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } else { + // Launch a 1D grid of threads + size_t nthreads = out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } +} + +void binary_op( + const std::vector& inputs, + array& out, + const std::string op) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt, true); + if (out.size() == 0) { + return; + } + + // Try to collapse contiguous dims + auto [shape, strides] = collapse_contiguous_dims(a, b, out); + auto& strides_a = strides[0]; + auto& strides_b = strides[1]; + auto& strides_out = strides[2]; + + std::string kernel_name; + { + std::ostringstream kname; + switch (bopt) { + case BinaryOpType::ScalarScalar: + kname << "ss"; + break; + case BinaryOpType::ScalarVector: + kname << "sv"; + break; + case BinaryOpType::VectorScalar: + kname << "vs"; + break; + case BinaryOpType::VectorVector: + kname << "vv"; + break; + case BinaryOpType::General: + kname << "g"; + if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { + kname << shape.size(); + } else { + kname << "n"; + } + break; + } + kname << op << type_to_name(a); + kernel_name = kname.str(); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + + auto kernel = get_binary_kernel(d, 
kernel_name, a, out); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + bool donate_a = a.data_shared_ptr() == nullptr; + bool donate_b = b.data_shared_ptr() == nullptr; + compute_encoder.set_input_array(donate_a ? out : a, 0); + compute_encoder.set_input_array(donate_b ? out : b, 1); + compute_encoder.set_output_array(out, 2); + + if (bopt == BinaryOpType::General) { + auto ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3); + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4); + } + + if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + compute_encoder->setBytes(&ndim, sizeof(int), 6); + } + + // Launch up to 3D grid of threads + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = out.size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } else { + // Launch a 1D grid of threads + size_t nthreads = + bopt == BinaryOpType::General ? 
out.size() : out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } +} + +void Add::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "add"); +} + +void ArcTan2::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "arctan2"); +} + +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + switch (op_) { + case BitwiseBinary::And: + binary_op(inputs, out, "bitwise_and"); + break; + case BitwiseBinary::Or: + binary_op(inputs, out, "bitwise_or"); + break; + case BitwiseBinary::Xor: + binary_op(inputs, out, "bitwise_xor"); + break; + case BitwiseBinary::LeftShift: + binary_op(inputs, out, "left_shift"); + break; + case BitwiseBinary::RightShift: + binary_op(inputs, out, "right_shift"); + break; + } +} + +void Divide::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "div"); +} + +void DivMod::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + binary_op(inputs, outputs, "divmod"); +} + +void Remainder::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "rem"); +} + +void Equal::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, equal_nan_ ? 
"naneq" : "eq"); +} + +void Greater::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "ge"); +} + +void GreaterEqual::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "geq"); +} + +void Less::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "le"); +} + +void LessEqual::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "leq"); +} + +void LogicalAnd::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "land"); +} + +void LogicalOr::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "lor"); +} + +void LogAddExp::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "lae"); +} + +void Maximum::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "max"); +} + +void Minimum::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "min"); +} + +void Multiply::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "mul"); +} + +void NotEqual::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "neq"); +} + +void Power::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "pow"); +} + +void Subtract::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "sub"); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 8b11daa3a..0bfb177a2 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -4,8 +4,8 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" -#include "mlx/backend/metal/compiled_preamble.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" #include "mlx/graph_utils.h" #include "mlx/primitives.h" @@ -190,7 +190,8 @@ void Compiled::eval_gpu( // If not we have to build it ourselves if (lib 
== nullptr) { std::ostringstream kernel; - kernel << metal::get_kernel_preamble() << std::endl; + kernel << metal::utils() << metal::unary_ops() << metal::binary_ops() + << metal::ternary_ops(); build_kernel( kernel, kernel_lib_ + "_contiguous", diff --git a/mlx/backend/metal/compiled_preamble.h b/mlx/backend/metal/compiled_preamble.h deleted file mode 100644 index 9122d3d54..000000000 --- a/mlx/backend/metal/compiled_preamble.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright © 2023-24 Apple Inc. - -#pragma once - -namespace mlx::core::metal { - -const char* get_kernel_preamble(); - -} diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 63ada6c0e..699772950 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -4,12 +4,14 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" namespace mlx::core { +constexpr int MAX_COPY_SPECIALIZED_DIMS = 5; + void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { if (ctype == CopyType::Vector) { // If the input is donateable, we are doing a vector copy and the types @@ -62,27 +64,34 @@ void copy_gpu_inplace( auto& strides_out_ = strides[1]; auto& d = metal::device(s.device); - std::ostringstream kname; - switch (ctype) { - case CopyType::Scalar: - kname << "scopy"; - break; - case CopyType::Vector: - kname << "vcopy"; - break; - case CopyType::General: - kname << "gcopy"; - break; - case CopyType::GeneralGeneral: - kname << "ggcopy"; - break; + std::string kernel_name; + { + std::ostringstream kname; + switch (ctype) { + case CopyType::Scalar: + kname << "s"; + break; + case CopyType::Vector: + kname << "v"; + break; + case CopyType::General: + kname << "g"; + break; + case CopyType::GeneralGeneral: + kname << "gg"; + break; + } + if ((ctype == CopyType::General || ctype == 
CopyType::GeneralGeneral) && + shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { + kname << shape.size(); + } + kname << "_copy"; + kname << type_to_name(in) << type_to_name(out); + kernel_name = kname.str(); } - kname << type_to_name(in) << type_to_name(out); - if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) && - shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { - kname << "_" << shape.size(); - } - auto kernel = d.get_kernel(kname.str()); + + auto kernel = get_copy_kernel(d, kernel_name, in, out); + auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); bool donate_in = in.data_shared_ptr() == nullptr; @@ -106,7 +115,7 @@ void copy_gpu_inplace( set_vector_bytes(compute_encoder, strides_out, ndim, 4); } - if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + if (ndim > MAX_COPY_SPECIALIZED_DIMS) { compute_encoder->setBytes(&ndim, sizeof(int), 5); } diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 479d5dc64..a22d8dd0e 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -285,7 +285,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) { NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init(); options->setFastMathEnabled(false); - options->setLanguageVersion(get_metal_version()); auto mtl_lib = device_->newLibrary(ns_code, options, &error); options->release(); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index cb1faf058..0fec4eb70 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -1,24 +1,35 @@ // Copyright © 2023-2024 Apple Inc. 
-#include -#include -#include -#include +#include -#include "mlx/backend/common/binary.h" +#include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/jit/includes.h" +#include "mlx/backend/metal/jit/indexing.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { -namespace { +constexpr int METAL_MAX_INDEX_ARRAYS = 20; -constexpr int METAL_MAX_INDEX_ARRAYS = 10; - -} // namespace +std::pair make_index_args( + const std::string& idx_type, + int nidx) { + std::ostringstream idx_args; + std::ostringstream idx_arr; + for (int i = 0; i < nidx; ++i) { + idx_args << fmt::format( + "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i); + idx_arr << fmt::format("idx{0}", i); + if (i < nidx - 1) { + idx_args << "\n"; + idx_arr << ","; + } + } + return {idx_args.str(), idx_arr.str()}; +} void Gather::eval_gpu(const std::vector& inputs, array& out) { auto& src = inputs[0]; @@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { int idx_ndim = nidx ? inputs[1].ndim() : 0; size_t ndim = src.ndim(); - std::ostringstream kname; + std::string lib_name; + std::string kernel_name; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; - kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx; - if (idx_ndim <= 1) { - kname << "_" << idx_ndim; + { + std::ostringstream kname; + kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx + << "_" << idx_ndim; + lib_name = kname.str(); + kernel_name = lib_name; + } + + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::gather(); + std::string out_type_str = get_type_string(out.dtype()); + std::string idx_type_str = + nidx ? 
get_type_string(inputs[1].dtype()) : "bool"; + auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); + + // Index dimension specializations + kernel_source << fmt::format( + gather_kernels, + type_to_name(out) + idx_type_name, + out_type_str, + idx_type_str, + nidx, + idx_args, + idx_arr, + idx_ndim); + lib = d.get_library(lib_name, kernel_source.str()); } auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kernel_name, lib); compute_encoder->setComputePipelineState(kernel); size_t slice_size = 1; @@ -102,8 +139,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&idx_ndim, sizeof(int), 9); // Set index buffers - for (int i = 1; i < nidx + 1; ++i) { - compute_encoder.set_input_array(inputs[i], 20 + i); + for (int i = 0; i < nidx; ++i) { + compute_encoder.set_input_array(inputs[i + 1], 20 + i); } // Launch grid @@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - // Get kernel name - std::ostringstream kname; - std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; - int idx_ndim = nidx ? inputs[1].ndim() : 0; bool index_nd1_specialization = (idx_ndim == 1); @@ -159,32 +192,85 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { index_nd1_specialization &= inputs[i].flags().row_contiguous; } - if (index_nd1_specialization) { - kname << "scatter_1d_index" << type_to_name(out) << idx_type_name; - } else { - kname << "scatter" << type_to_name(out) << idx_type_name; - } + std::string lib_name; + std::string kernel_name; + std::string idx_type_name = nidx ? 
type_to_name(inputs[1]) : ""; + std::string op_name; switch (reduce_type_) { case Scatter::None: - kname << "_none"; + op_name = "none"; break; case Scatter::Sum: - kname << "_sum"; + op_name = "sum"; break; case Scatter::Prod: - kname << "_prod"; + op_name = "prod"; break; case Scatter::Max: - kname << "_max"; + op_name = "max"; break; case Scatter::Min: - kname << "_min"; + op_name = "min"; break; } - kname << "_" << nidx; + + { + std::ostringstream kname; + if (index_nd1_specialization) { + kname << "scatter_1d_index" << type_to_name(out) << idx_type_name; + } else { + kname << "scatter" << type_to_name(out) << idx_type_name; + } + kname << "_" << op_name << "_" << nidx; + lib_name = kname.str(); + kernel_name = kname.str(); + } + + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::reduction() << metal::scatter(); + + std::string out_type_str = get_type_string(out.dtype()); + std::string idx_type_str = + nidx ? 
get_type_string(inputs[1].dtype()) : "bool"; + std::string op_type; + switch (reduce_type_) { + case Scatter::None: + op_type = "None"; + break; + case Scatter::Sum: + op_type = "Sum<{0}>"; + break; + case Scatter::Prod: + op_type = "Prod<{0}>"; + break; + case Scatter::Max: + op_type = "Max<{0}>"; + break; + case Scatter::Min: + op_type = "Min<{0}>"; + break; + } + if (reduce_type_ != Scatter::None) { + op_type = fmt::format(op_type, out_type_str); + } + auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); + + kernel_source << fmt::format( + scatter_kernels, + type_to_name(out) + idx_type_name + "_" + op_name, + out_type_str, + idx_type_str, + op_type, + nidx, + idx_args, + idx_arr); + lib = d.get_library(lib_name, kernel_source.str()); + } auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kernel_name, lib); auto& upd = inputs.back(); size_t nthreads = upd.size(); @@ -209,8 +295,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&upd_size, sizeof(size_t), 5); // Set index buffers - for (int i = 1; i < nidx + 1; ++i) { - compute_encoder.set_input_array(inputs[i], 20 + i); + for (int i = 0; i < nidx; ++i) { + compute_encoder.set_input_array(inputs[i + 1], 20 + i); } // Launch grid @@ -279,8 +365,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); // Set index buffers - for (int i = 1; i < nidx + 1; ++i) { - compute_encoder.set_input_array(inputs[i], 20 + i); + for (int i = 0; i < nidx; ++i) { + compute_encoder.set_input_array(inputs[i + 1], 20 + i); } // Launch grid diff --git a/mlx/backend/metal/jit/binary.h b/mlx/backend/metal/jit/binary.h new file mode 100644 index 000000000..febc8be6c --- /dev/null +++ b/mlx/backend/metal/jit/binary.h @@ -0,0 +1,87 @@ +// Copyright © 2024 Apple Inc. 
+ +constexpr std::string_view binary_kernels = R"( +template [[host_name("ss{0}")]] [[kernel]] +void binary_ss<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + uint index [[thread_position_in_grid]]); +template [[host_name("vs{0}")]] [[kernel]] +void binary_vs<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + uint index [[thread_position_in_grid]]); +template [[host_name("sv{0}")]] [[kernel]] +void binary_sv<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + uint index [[thread_position_in_grid]]); +template [[host_name("vv{0}")]] [[kernel]] +void binary_vv<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + uint index [[thread_position_in_grid]]); +template [[host_name("g4{0}")]] [[kernel]] void +binary_g_nd<{1}, {2}, {3}, 4>( + device const {1}* a, + device const {1}* b, + device {2}* c, + constant const int shape[4], + constant const size_t a_strides[4], + constant const size_t b_strides[4], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("g5{0}")]] [[kernel]] void +binary_g_nd<{1}, {2}, {3}, 5>( + device const {1}* a, + device const {1}* b, + device {2}* c, + constant const int shape[5], + constant const size_t a_strides[5], + constant const size_t b_strides[5], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); + +template [[host_name("g1{0}")]] [[kernel]] void +binary_g_nd1<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]); +template [[host_name("g2{0}")]] [[kernel]] void +binary_g_nd2<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim 
[[threads_per_grid]]); +template [[host_name("g3{0}")]] [[kernel]] void +binary_g_nd3<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); + +template [[host_name("gn{0}")]] [[kernel]] +void binary_g<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +)"; diff --git a/mlx/backend/metal/jit/binary_two.h b/mlx/backend/metal/jit/binary_two.h new file mode 100644 index 000000000..54b0c6296 --- /dev/null +++ b/mlx/backend/metal/jit/binary_two.h @@ -0,0 +1,98 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view binary_two_kernels = R"( +template [[host_name("ss{0}")]] [[kernel]] +void binary_ss<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + uint index [[thread_position_in_grid]]); +template [[host_name("vs{0}")]] [[kernel]] +void binary_vs<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + uint index [[thread_position_in_grid]]); +template [[host_name("sv{0}")]] [[kernel]] +void binary_sv<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + uint index [[thread_position_in_grid]]); +template [[host_name("vv{0}")]] [[kernel]] +void binary_vv<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + uint index [[thread_position_in_grid]]); + +template [[host_name("g4{0}")]] [[kernel]] void +binary_g_nd<{1}, {2}, {3}, 4>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + constant const int shape[4], + constant const size_t 
a_strides[4], + constant const size_t b_strides[4], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("g5{0}")]] [[kernel]] void +binary_g_nd<{1}, {2}, {3}, 5>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + constant const int shape[5], + constant const size_t a_strides[5], + constant const size_t b_strides[5], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); + +template [[host_name("g1{0}")]] [[kernel]] void +binary_g_nd1<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]); +template [[host_name("g2{0}")]] [[kernel]] void +binary_g_nd2<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]); +template [[host_name("g3{0}")]] [[kernel]] void +binary_g_nd3<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); + +template [[host_name("gn{0}")]] [[kernel]] +void binary_g<{1}, {2}, {3}>( + device const {1}* a, + device const {1}* b, + device {2}* c, + device {2}* d, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +)"; diff --git a/mlx/backend/metal/jit/copy.h b/mlx/backend/metal/jit/copy.h new file mode 100644 index 000000000..167be8f84 --- /dev/null +++ b/mlx/backend/metal/jit/copy.h @@ -0,0 +1,100 @@ +// Copyright © 2024 Apple Inc. 
+ +constexpr std::string_view copy_kernels = R"( +template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + uint index [[thread_position_in_grid]]); +template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + uint index [[thread_position_in_grid]]); + +template [[host_name("g4_{0}")]] [[kernel]] void +copy_g_nd<{1}, {2}, 4>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("gg4_{0}")]] [[kernel]] void +copy_gg_nd<{1}, {2}, 4>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]); +template [[host_name("g5_{0}")]] [[kernel]] void +copy_g_nd<{1}, {2}, 5>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("gg5_{0}")]] [[kernel]] void +copy_gg_nd<{1}, {2}, 5>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]); +template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + uint index 
[[thread_position_in_grid]]); +template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]); +template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("gg1_{0}")]] [[kernel]] void +copy_gg_nd1<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + uint index [[thread_position_in_grid]]); +template [[host_name("gg2_{0}")]] [[kernel]] void +copy_gg_nd2<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint2 index [[thread_position_in_grid]]); +template [[host_name("gg3_{0}")]] [[kernel]] void +copy_gg_nd3<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]); + +template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>( + device const {1}* src [[buffer(0)]], + device {2}* dst [[buffer(1)]], + constant 
const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]]); +)"; diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h new file mode 100644 index 000000000..e98462374 --- /dev/null +++ b/mlx/backend/metal/jit/includes.h @@ -0,0 +1,21 @@ +// Copyright © 2023-24 Apple Inc. + +#pragma once + +namespace mlx::core::metal { + +const char* utils(); +const char* binary_ops(); +const char* unary_ops(); +const char* ternary_ops(); +const char* reduction(); +const char* gather(); +const char* scatter(); + +const char* unary(); +const char* binary(); +const char* binary_two(); +const char* copy(); +const char* ternary(); + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h new file mode 100644 index 000000000..80d2a1e83 --- /dev/null +++ b/mlx/backend/metal/jit/indexing.h @@ -0,0 +1,81 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +constexpr std::string_view gather_kernels = R"( +[[kernel]] void gather{0}_{3}_{6}( + const device {1}* src [[buffer(0)]], + device {1}* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant size_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const constant int* idx_shapes [[buffer(7)]], + const constant size_t* idx_strides [[buffer(8)]], + const constant int& idx_ndim [[buffer(9)]], + {4} + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) {{ + Indices<{2}, {3}> idxs{{ + {{ {5} }}, idx_shapes, idx_strides, idx_ndim}}; + + return gather_impl<{1}, {2}, {3}, {6}>( + src, + out, + src_shape, + src_strides, + src_ndim, + slice_sizes, + axes, + idxs, + index, + grid_dim); +}} +)"; + +constexpr std::string_view scatter_kernels = R"( +[[kernel]] void scatter_1d_index{0}_{4}( + const device {1}* updates [[buffer(1)]], + device mlx_atomic<{1}>* out [[buffer(2)]], + const constant int* out_shape [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& upd_size [[buffer(5)]], + {5} + uint2 gid [[thread_position_in_grid]]) {{ + const array idx_buffers = {{ {6} }}; + return scatter_1d_index_impl<{1}, {2}, {3}, {4}>( + updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); +}} + +[[kernel]] void scatter{0}_{4}( + const device {1}* updates [[buffer(1)]], + device mlx_atomic<{1}>* out [[buffer(2)]], + const constant int* upd_shape [[buffer(3)]], + const constant size_t* upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int* out_shape [[buffer(7)]], + const constant size_t* out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const constant int* idx_shapes [[buffer(11)]], + const constant 
size_t* idx_strides [[buffer(12)]], + const constant int& idx_ndim [[buffer(13)]], + {5} + uint2 gid [[thread_position_in_grid]]) {{ + Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}}; + + return scatter_impl<{1}, {2}, {3}, {4}>( + updates, + out, + upd_shape, + upd_strides, + upd_ndim, + upd_size, + out_shape, + out_strides, + out_ndim, + axes, + idxs, + gid); +}} +)"; diff --git a/mlx/backend/metal/jit/ternary.h b/mlx/backend/metal/jit/ternary.h new file mode 100644 index 000000000..8b49e1311 --- /dev/null +++ b/mlx/backend/metal/jit/ternary.h @@ -0,0 +1,80 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view ternary_kernels = R"( +template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + uint index [[thread_position_in_grid]]); + +template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); + +template [[host_name("g1_{0}")]] [[kernel]] void +ternary_g_nd1<{1}, {2}>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + constant const size_t& a_strides, + constant const size_t& b_strides, + constant const size_t& c_strides, + uint index [[thread_position_in_grid]]); +template [[host_name("g2_{0}")]] [[kernel]] void +ternary_g_nd2<{1}, {2}>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + constant const size_t c_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]); +template [[host_name("g3_{0}")]] 
[[kernel]] void +ternary_g_nd3<{1}, {2}>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + constant const size_t c_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("g4_{0}")]] [[kernel]] void +ternary_g_nd<{1}, {2}, 4>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + constant const int shape[4], + constant const size_t a_strides[4], + constant const size_t b_strides[4], + constant const size_t c_strides[4], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +template [[host_name("g5_{0}")]] [[kernel]] void +ternary_g_nd<{1}, {2}, 5>( + device const bool* a, + device const {1}* b, + device const {1}* c, + device {1}* d, + constant const int shape[5], + constant const size_t a_strides[5], + constant const size_t b_strides[5], + constant const size_t c_strides[5], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]); +)"; diff --git a/mlx/backend/metal/jit/unary.h b/mlx/backend/metal/jit/unary.h new file mode 100644 index 000000000..d35957fe4 --- /dev/null +++ b/mlx/backend/metal/jit/unary.h @@ -0,0 +1,16 @@ +// Copyright © 2024 Apple Inc. 
+ +constexpr std::string_view unary_kernels = R"( +template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>( + device const {1}* in, + device {1}* out, + uint index [[thread_position_in_grid]]); + +template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>( + device const {1}* in, + device {1}* out, + device const int* in_shape, + device const size_t* in_strides, + device const int& ndim, + uint index [[thread_position_in_grid]]); +)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp new file mode 100644 index 000000000..60c9c7d2d --- /dev/null +++ b/mlx/backend/metal/jit_kernels.cpp @@ -0,0 +1,124 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/metal/jit/binary.h" +#include "mlx/backend/metal/jit/binary_two.h" +#include "mlx/backend/metal/jit/copy.h" +#include "mlx/backend/metal/jit/includes.h" +#include "mlx/backend/metal/jit/ternary.h" +#include "mlx/backend/metal/jit/unary.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core { + +std::string op_name(const array& arr) { + std::ostringstream op_t; + arr.primitive().print(op_t); + return op_t.str(); +} + +MTL::ComputePipelineState* get_unary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + std::string lib_name = kernel_name.substr(1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::unary_ops() << metal::unary() + << fmt::format( + unary_kernels, + lib_name, + get_type_string(out.dtype()), + op_name(out)); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_binary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + std::string lib_name = kernel_name.substr(2); + auto lib = 
d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::binary_ops() << metal::binary() + << fmt::format( + binary_kernels, + lib_name, + get_type_string(in.dtype()), + get_type_string(out.dtype()), + op_name(out)); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_binary_two_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + std::string lib_name = kernel_name.substr(2); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::binary_ops() + << metal::binary_two() + << fmt::format( + binary_two_kernels, + lib_name, + get_type_string(in.dtype()), + get_type_string(out.dtype()), + op_name(out)); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_ternary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary() + << fmt::format( + ternary_kernels, + lib_name, + get_type_string(out.dtype()), + op_name(out)); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +MTL::ComputePipelineState* get_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name); + if (lib == nullptr) { + std::ostringstream kernel_source; + kernel_source << metal::utils() << metal::copy() + << fmt::format( + copy_kernels, + lib_name, + 
get_type_string(in.dtype()), + get_type_string(out.dtype())); + lib = d.get_library(lib_name, kernel_source.str()); + } + return d.get_kernel(kernel_name, lib); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h new file mode 100644 index 000000000..e7afc04b3 --- /dev/null +++ b/mlx/backend/metal/kernels.h @@ -0,0 +1,36 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" + +namespace mlx::core { + +MTL::ComputePipelineState* get_unary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + +MTL::ComputePipelineState* get_binary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + +MTL::ComputePipelineState* get_binary_two_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + +MTL::ComputePipelineState* get_ternary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out); + +MTL::ComputePipelineState* get_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index ec406327e..0e29c0650 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -3,13 +3,8 @@ set( ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h - ${CMAKE_CURRENT_SOURCE_DIR}/binary.h ${CMAKE_CURRENT_SOURCE_DIR}/complex.h ${CMAKE_CURRENT_SOURCE_DIR}/defines.h - ${CMAKE_CURRENT_SOURCE_DIR}/erf.h - ${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h - ${CMAKE_CURRENT_SOURCE_DIR}/unary.h ${CMAKE_CURRENT_SOURCE_DIR}/utils.h ) @@ -17,10 +12,7 @@ set( KERNELS "arange" "arg_reduce" - "binary" - "binary_two" "conv" - "copy" "fft" "gemv" "quantized" @@ 
-32,12 +24,30 @@ set( "scaled_dot_product_attention" "softmax" "sort" - "ternary" - "unary" - "gather" - "scatter" ) +if (NOT MLX_METAL_JIT) +set( + KERNELS + ${KERNELS} + "binary" + "binary_two" + "unary" + "ternary" + "copy" +) +set( + HEADERS + ${HEADERS} + unary_ops.h + unary.h + binary_ops.h + binary.h + ternary.h + copy.h +) +endif() + function(build_kernel_base TARGET SRCFILE DEPS) set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION}) if(MLX_METAL_DEBUG) diff --git a/mlx/backend/metal/kernels/atomic.h b/mlx/backend/metal/kernels/atomic.h index c0f4b9ed8..7ee6ac294 100644 --- a/mlx/backend/metal/kernels/atomic.h +++ b/mlx/backend/metal/kernels/atomic.h @@ -2,9 +2,11 @@ #pragma once +#ifndef MLX_METAL_JIT #include #include #include "mlx/backend/metal/kernels/bf16.h" +#endif using namespace metal; diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 9eea3c7b8..ca55bdebf 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -1,273 +1,113 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2024 Apple Inc. 
-#pragma once +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[0]); +} -#include -#include +template +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[index]); +} -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/utils.h" +template +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[index], b[0]); +} -struct Add { - template - T operator()(T x, T y) { - return x + y; - } -}; +template +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[index], b[index]); +} -struct Divide { - template - T operator()(T x, T y) { - return x / y; - } -}; +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + c[index] = Op()(a[a_idx], b[b_idx]); +} -struct Remainder { - template - metal::enable_if_t & !metal::is_signed_v, T> - operator()(T x, T y) { - return x % y; - } - template - metal::enable_if_t & metal::is_signed_v, T> - operator()(T x, T y) { - auto r = x % y; - if (r != 0 && (r < 0 != y < 0)) { - r += y; - } - return r; - } - template - metal::enable_if_t, T> operator()(T x, T y) { - T r = fmod(x, y); - if (r != 0 && (r < 0 != y < 0)) { - r += y; - } - return r; - } - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - return x % y; - } -}; +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + constant const size_t 
a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} -struct Equal { - template - bool operator()(T x, T y) { - return x == y; - } -}; +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} -struct NaNEqual { - template - bool operator()(T x, T y) { - return x == y || (metal::isnan(x) && metal::isnan(y)); - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x == y || - (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && - metal::isnan(y.imag)) || - (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); - } -}; +template +[[kernel]] void binary_g_nd( + device const T* a, + device const T* b, + device U* c, + constant const int shape[DIM], + constant const size_t a_strides[DIM], + constant const size_t b_strides[DIM], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + c[out_idx] = Op()(a[idx.x], b[idx.y]); +} -struct Greater { - template - bool operator()(T x, T y) { - return x > y; - } -}; - -struct GreaterEqual { - template - bool 
operator()(T x, T y) { - return x >= y; - } -}; - -struct Less { - template - bool operator()(T x, T y) { - return x < y; - } -}; - -struct LessEqual { - template - bool operator()(T x, T y) { - return x <= y; - } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - if (metal::isnan(x) || metal::isnan(y)) { - return metal::numeric_limits::quiet_NaN(); - } - constexpr T inf = metal::numeric_limits::infinity(); - T maxval = metal::max(x, y); - T minval = metal::min(x, y); - return (minval == -inf || maxval == inf) - ? maxval - : (maxval + log1p(metal::exp(minval - maxval))); - }; -}; - -struct Maximum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::max(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x > y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x > y ? x : y; - } -}; - -struct Minimum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::min(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x < y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x < y ? 
x : y; - } -}; - -struct Multiply { - template - T operator()(T x, T y) { - return x * y; - } -}; - -struct NotEqual { - template - bool operator()(T x, T y) { - return x != y; - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x.real != y.real || x.imag != y.imag; - } -}; - -struct Power { - template - metal::enable_if_t, T> operator()(T base, T exp) { - return metal::pow(base, exp); - } - - template - metal::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - auto x_theta = metal::atan(x.imag / x.real); - auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); - auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); - auto phase = y.imag * x_ln_r + y.real * x_theta; - return {mag * metal::cos(phase), mag * metal::sin(phase)}; - } -}; - -struct Subtract { - template - T operator()(T x, T y) { - return x - y; - } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { - return x && y; - }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { - return x || y; - }; -}; - -struct BitwiseAnd { - template - T operator()(T x, T y) { - return x & y; - }; -}; - -struct BitwiseOr { - template - T operator()(T x, T y) { - return x | y; - }; -}; - -struct BitwiseXor { - template - T operator()(T x, T y) { - return x ^ y; - }; -}; - -struct LeftShift { - template - T operator()(T x, T y) { - return x << y; - }; -}; - -struct RightShift { - template - T operator()(T x, T y) { - return x >> y; - }; -}; - -struct ArcTan2 { - template - T operator()(T y, T x) { - return metal::precise::atan2(y, x); - } -}; +template +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index 
[[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); + size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); + c[out_idx] = Op()(a[idx.x], b[idx.y]); +} diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 67967f928..f85f96a53 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -1,130 +1,24 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2024 Apple Inc. +#include +#include + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -template -[[kernel]] void binary_op_ss( - device const T* a, - device const T* b, - device U* c, - uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[0]); -} - -template -[[kernel]] void binary_op_sv( - device const T* a, - device const T* b, - device U* c, - uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); -} - -template -[[kernel]] void binary_op_vs( - device const T* a, - device const T* b, - device U* c, - uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); -} - -template -[[kernel]] void binary_op_vv( - device const T* a, - device const T* b, - device U* c, - uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); -} - -template -[[kernel]] void binary_op_g_nd1( - device const T* a, - device const T* b, - device U* c, - constant const size_t& a_stride, - constant const size_t& b_stride, - uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); - c[index] = Op()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_op_g_nd2( - device const T* a, - device const T* b, - device U* c, - constant const size_t a_strides[2], - constant const size_t 
b_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * index.y; - c[out_idx] = Op()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_op_g_nd3( - device const T* a, - device const T* b, - device U* c, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - c[out_idx] = Op()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_op_g_nd( - device const T* a, - device const T* b, - device U* c, - constant const int shape[DIM], - constant const size_t a_strides[DIM], - constant const size_t b_strides[DIM], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - c[out_idx] = Op()(a[idx.x], b[idx.y]); -} - -template -[[kernel]] void binary_op_g( - device const T* a, - device const T* b, - device U* c, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); - size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); - c[out_idx] = Op()(a[idx.x], b[idx.y]); -} - #define instantiate_binary(name, itype, otype, op, bopt) \ template \ - [[host_name(name)]] [[kernel]] void binary_op_##bopt( \ + [[host_name(name)]] [[kernel]] void binary_##bopt( \ device 
const itype* a, \ device const itype* b, \ device otype* c, \ uint index [[thread_position_in_grid]]); #define instantiate_binary_g_dim(name, itype, otype, op, dims) \ - template [[host_name(name "_" #dims)]] [[kernel]] void \ - binary_op_g_nd( \ + template [[host_name("g" #dims name)]] [[kernel]] void \ + binary_g_nd( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -135,16 +29,16 @@ template uint3 grid_dim [[threads_per_grid]]); #define instantiate_binary_g_nd(name, itype, otype, op) \ - template [[host_name(name "_1")]] [[kernel]] void \ - binary_op_g_nd1( \ + template [[host_name("g1" name)]] [[kernel]] void \ + binary_g_nd1( \ device const itype* a, \ device const itype* b, \ device otype* c, \ constant const size_t& a_stride, \ constant const size_t& b_stride, \ uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] [[kernel]] void \ - binary_op_g_nd2( \ + template [[host_name("g2" name)]] [[kernel]] void \ + binary_g_nd2( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -152,8 +46,8 @@ template constant const size_t b_strides[2], \ uint2 index [[thread_position_in_grid]], \ uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] [[kernel]] void \ - binary_op_g_nd3( \ + template [[host_name("g3" name)]] [[kernel]] void \ + binary_g_nd3( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -162,30 +56,28 @@ template uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); \ instantiate_binary_g_dim(name, itype, otype, op, 4) \ - instantiate_binary_g_dim(name, itype, otype, op, 5) + instantiate_binary_g_dim(name, itype, otype, op, 5) -#define instantiate_binary_g(name, itype, otype, op) \ - template [[host_name(name)]] [[kernel]] void binary_op_g( \ - device const itype* a, \ - device const itype* b, \ - device otype* c, \ - constant const int* shape, \ - constant const size_t* a_strides, \ - constant const size_t* 
b_strides, \ - constant const int& ndim, \ - uint3 index [[thread_position_in_grid]], \ +#define instantiate_binary_g(name, itype, otype, op) \ + template [[host_name("gn" name)]] [[kernel]] void binary_g( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const int* shape, \ + constant const size_t* a_strides, \ + constant const size_t* b_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); -// clang-format off #define instantiate_binary_all(name, tname, itype, otype, op) \ instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ - instantiate_binary_g("g" #name #tname, itype, otype, op) \ - instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on + instantiate_binary_g(#name #tname, itype, otype, op) \ + instantiate_binary_g_nd(#name #tname, itype, otype, op) -// clang-format off #define instantiate_binary_integer(name, op) \ instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ @@ -194,22 +86,19 @@ template instantiate_binary_all(name, int8, int8_t, int8_t, op) \ instantiate_binary_all(name, int16, int16_t, int16_t, op) \ instantiate_binary_all(name, int32, int32_t, int32_t, op) \ - instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on + instantiate_binary_all(name, int64, int64_t, int64_t, op) -// clang-format off #define instantiate_binary_float(name, op) \ instantiate_binary_all(name, float16, half, half, op) \ instantiate_binary_all(name, float32, float, float, op) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on + instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) -// clang-format off #define 
instantiate_binary_types(name, op) \ instantiate_binary_all(name, bool_, bool, bool, op) \ instantiate_binary_integer(name, op) \ instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ - instantiate_binary_float(name, op) // clang-format on + instantiate_binary_float(name, op) -// clang-format off #define instantiate_binary_types_bool(name, op) \ instantiate_binary_all(name, bool_, bool, bool, op) \ instantiate_binary_all(name, uint8, uint8_t, bool, op) \ @@ -223,9 +112,8 @@ template instantiate_binary_all(name, float16, half, bool, op) \ instantiate_binary_all(name, float32, float, bool, op) \ instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \ - instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on + instantiate_binary_all(name, complex64, complex64_t, bool, op) -// clang-format off instantiate_binary_types(add, Add) instantiate_binary_types(div, Divide) instantiate_binary_types_bool(eq, Equal) diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h new file mode 100644 index 000000000..9cd2126cd --- /dev/null +++ b/mlx/backend/metal/kernels/binary_ops.h @@ -0,0 +1,296 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + T operator()(T x, T y) { + return x / y; + } + template <> + float operator()(float x, float y) { + return trunc(x / y); + } + template <> + half operator()(half x, half y) { + return trunc(x / y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return trunc(x / y); + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T 
operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? 
x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + auto x_theta = metal::atan(x.imag / x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + T operator()(T y, T x) { + return metal::precise::atan2(y, x); + } +}; + +struct DivMod { + template + metal::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h new file mode 
100644 index 000000000..3890adbce --- /dev/null +++ b/mlx/backend/metal/kernels/binary_two.h @@ -0,0 +1,140 @@ +// Copyright © 2024 Apple Inc. + +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[0], b[0]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[0], b[index]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[index], b[0]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[index], b[index]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + auto out = Op()(a[a_idx], b[b_idx]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] 
= out[1]; +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template +[[kernel]] void binary_g_nd( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int shape[DIM], + constant const size_t a_strides[DIM], + constant const size_t b_strides[DIM], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + auto out = Op()(a[idx.x], b[idx.y]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); + size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); + auto out = Op()(a[idx.x], b[idx.y]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index c192561d7..0f63227d9 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -1,212 +1,24 @@ -// Copyright © 2023 Apple Inc. 
- +// Copyright © 2024 Apple Inc. #include #include -#include "mlx/backend/metal/kernels/bf16.h" +// clang-format off #include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/binary_ops.h" +#include "mlx/backend/metal/kernels/binary_two.h" -struct FloorDivide { - template - T operator()(T x, T y) { - return x / y; - } - template <> - float operator()(float x, float y) { - return trunc(x / y); - } - template <> - half operator()(half x, half y) { - return trunc(x / y); - } - template <> - bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { - return trunc(x / y); - } -}; - -struct Remainder { - template - metal::enable_if_t & !metal::is_signed_v, T> - operator()(T x, T y) { - return x % y; - } - template - metal::enable_if_t & metal::is_signed_v, T> - operator()(T x, T y) { - auto r = x % y; - if (r != 0 && (r < 0 != y < 0)) { - r += y; - } - return r; - } - template - metal::enable_if_t, T> operator()(T x, T y) { - T r = fmod(x, y); - if (r != 0 && (r < 0 != y < 0)) { - r += y; - } - return r; - } - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - return x % y; - } -}; - -template -[[kernel]] void binary_op_s2s( - device const T* a, - device const T* b, - device U* c, - device U* d, - uint index [[thread_position_in_grid]]) { - c[index] = Op1()(a[0], b[0]); - d[index] = Op2()(a[0], b[0]); -} - -template -[[kernel]] void binary_op_ss( - device const T* a, - device const T* b, - device U* c, - device U* d, - uint index [[thread_position_in_grid]]) { - c[index] = Op1()(a[0], b[0]); - d[index] = Op2()(a[0], b[0]); -} - -template -[[kernel]] void binary_op_sv( - device const T* a, - device const T* b, - device U* c, - device U* d, - uint index [[thread_position_in_grid]]) { - c[index] = Op1()(a[0], b[index]); - d[index] = Op2()(a[0], b[index]); -} - -template -[[kernel]] void binary_op_vs( - device const T* a, - device const T* b, - device U* c, - device U* d, - uint index [[thread_position_in_grid]]) { - c[index] = 
Op1()(a[index], b[0]); - d[index] = Op2()(a[index], b[0]); -} - -template -[[kernel]] void binary_op_vv( - device const T* a, - device const T* b, - device U* c, - device U* d, - uint index [[thread_position_in_grid]]) { - c[index] = Op1()(a[index], b[index]); - d[index] = Op2()(a[index], b[index]); -} - -template -[[kernel]] void binary_op_g_nd1( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const size_t& a_stride, - constant const size_t& b_stride, - uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); - c[index] = Op1()(a[a_idx], b[b_idx]); - d[index] = Op2()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_op_g_nd2( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const size_t a_strides[2], - constant const size_t b_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * index.y; - c[out_idx] = Op1()(a[a_idx], b[b_idx]); - d[out_idx] = Op2()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_op_g_nd3( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - c[out_idx] = Op1()(a[a_idx], b[b_idx]); - d[out_idx] = Op2()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_op_g_nd( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int shape[DIM], - constant const size_t a_strides[DIM], - constant const 
size_t b_strides[DIM], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - c[out_idx] = Op1()(a[idx.x], b[idx.y]); - d[out_idx] = Op2()(a[idx.x], b[idx.y]); -} - -template -[[kernel]] void binary_op_g( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); - size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); - c[out_idx] = Op1()(a[idx.x], b[idx.y]); - d[out_idx] = Op2()(a[idx.x], b[idx.y]); -} - -#define instantiate_binary(name, itype, otype, op1, op2, bopt) \ +#define instantiate_binary(name, itype, otype, op, bopt) \ template [[host_name(name)]] [[kernel]] void \ - binary_op_##bopt( \ + binary_##bopt( \ device const itype* a, \ device const itype* b, \ device otype* c, \ device otype* d, \ uint index [[thread_position_in_grid]]); -#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \ - template [[host_name(name "_" #dims)]] [[kernel]] void \ - binary_op_g_nd( \ +#define instantiate_binary_g_dim(name, itype, otype, op, dims) \ + template [[host_name("g" #dims name)]] [[kernel]] void \ + binary_g_nd( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -217,10 +29,9 @@ template uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); -// clang-format off -#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \ - template [[host_name(name "_1")]] [[kernel]] void \ - binary_op_g_nd1( \ +#define instantiate_binary_g_nd(name, itype, otype, op) \ + template [[host_name("g1" name)]] 
[[kernel]] void \ + binary_g_nd1( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -228,8 +39,8 @@ template constant const size_t& a_stride, \ constant const size_t& b_stride, \ uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] [[kernel]] void \ - binary_op_g_nd2( \ + template [[host_name("g2" name)]] [[kernel]] void \ + binary_g_nd2( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -238,8 +49,8 @@ template constant const size_t b_strides[2], \ uint2 index [[thread_position_in_grid]], \ uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] [[kernel]] void \ - binary_op_g_nd3( \ + template [[host_name("g3" name)]] [[kernel]] void \ + binary_g_nd3( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -248,12 +59,12 @@ template constant const size_t b_strides[3], \ uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); \ - instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \ - instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on + instantiate_binary_g_dim(name, itype, otype, op, 4) \ + instantiate_binary_g_dim(name, itype, otype, op, 5) -#define instantiate_binary_g(name, itype, otype, op1, op2) \ - template [[host_name(name)]] [[kernel]] void \ - binary_op_g( \ +#define instantiate_binary_g(name, itype, otype, op) \ + template [[host_name("gn" name)]] [[kernel]] void \ + binary_g( \ device const itype* a, \ device const itype* b, \ device otype* c, \ @@ -265,33 +76,30 @@ template uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); -// clang-format off -#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \ - instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \ - instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \ - instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \ - instantiate_binary("vv" #name 
#tname, itype, otype, op1, op2, vv) \ - instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \ - instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on +#define instantiate_binary_all(name, tname, itype, otype, op) \ + instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ + instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ + instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ + instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ + instantiate_binary_g(#name #tname, itype, otype, op) \ + instantiate_binary_g_nd(#name #tname, itype, otype, op) -// clang-format off -#define instantiate_binary_float(name, op1, op2) \ - instantiate_binary_all(name, float16, half, half, op1, op2) \ - instantiate_binary_all(name, float32, float, float, op1, op2) \ - instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on +#define instantiate_binary_float(name, op) \ + instantiate_binary_all(name, float16, half, half, op) \ + instantiate_binary_all(name, float32, float, float, op) \ + instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) -// clang-format off -#define instantiate_binary_types(name, op1, op2) \ - instantiate_binary_all(name, bool_, bool, bool, op1, op2) \ - instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \ - instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \ - instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \ - instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \ - instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \ - instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \ - instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \ - instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \ - instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \ - instantiate_binary_float(name, op1, op2) +#define 
instantiate_binary_types(name, op) \ + instantiate_binary_all(name, bool_, bool, bool, op) \ + instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ + instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ + instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \ + instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \ + instantiate_binary_all(name, int8, int8_t, int8_t, op) \ + instantiate_binary_all(name, int16, int16_t, int16_t, op) \ + instantiate_binary_all(name, int32, int32_t, int32_t, op) \ + instantiate_binary_all(name, int64, int64_t, int64_t, op) \ + instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ + instantiate_binary_float(name, op) -instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on +instantiate_binary_types(divmod, DivMod) // clang-format on diff --git a/mlx/backend/metal/kernels/compiled_preamble.h b/mlx/backend/metal/kernels/compiled_preamble.h deleted file mode 100644 index 12fdc8117..000000000 --- a/mlx/backend/metal/kernels/compiled_preamble.h +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include "mlx/backend/metal/kernels/binary.h" -#include "mlx/backend/metal/kernels/ternary.h" -#include "mlx/backend/metal/kernels/unary.h" - -typedef half float16_t; diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h new file mode 100644 index 000000000..451b7bb4c --- /dev/null +++ b/mlx/backend/metal/kernels/copy.h @@ -0,0 +1,144 @@ +// Copyright © 2024 Apple Inc. 
+ +template +[[kernel]] void copy_s( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + dst[index] = static_cast(src[0]); +} + +template +[[kernel]] void copy_v( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + dst[index] = static_cast(src[index]); +} + +template +[[kernel]] void copy_g_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + dst[index] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + dst[dst_idx] = 
static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); + int64_t dst_idx = + index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides 
[[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); + auto dst_idx = elem_to_loc_nd(index, src_shape, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); + auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); + dst[dst_idx] = static_cast(src[src_idx]); +} diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index 01518144b..df21c75e0 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -1,150 +1,9 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2024 Apple Inc. 
-#include "mlx/backend/metal/kernels/bf16.h" +// clang-format off #include "mlx/backend/metal/kernels/utils.h" - -template -[[kernel]] void copy_s( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); -} - -template -[[kernel]] void copy_v( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); -} - -template -[[kernel]] void copy_g_nd1( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t& src_stride [[buffer(3)]], - uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - dst[index] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g_nd2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g_nd3( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g_nd( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); - 
int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int& ndim [[buffer(5)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); - int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg_nd1( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t& src_stride [[buffer(3)]], - constant const int64_t& dst_stride [[buffer(4)]], - uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg_nd2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg_nd3( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - auto dst_idx = elem_to_loc_3(index, dst_strides); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg_nd( - device const T* src [[buffer(0)]], - 
device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); - auto dst_idx = elem_to_loc_nd(index, src_shape, dst_strides); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - constant const int& ndim [[buffer(5)]], - uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); - auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); - dst[dst_idx] = static_cast(src[src_idx]); -} +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/copy.h" #define instantiate_copy(name, itype, otype, ctype) \ template [[host_name(name)]] [[kernel]] void copy_##ctype( \ @@ -152,92 +11,90 @@ template device otype* dst [[buffer(1)]], \ uint index [[thread_position_in_grid]]); -#define instantiate_copy_g_dim(name, itype, otype, dims) \ - template [[host_name(name "_" #dims)]] [[kernel]] void \ - copy_g_nd( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - template [[host_name("g" name "_" #dims)]] [[kernel]] void \ - copy_gg_nd( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ +#define 
instantiate_copy_g_dim(name, itype, otype, dims) \ + template [[host_name("g" #dims "_" name)]] [[kernel]] void \ + copy_g_nd( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("gg" #dims "_" name)]] [[kernel]] void \ + copy_gg_nd( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ uint3 index [[thread_position_in_grid]]); -#define instantiate_copy_g_nd(name, itype, otype) \ - template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t& src_stride [[buffer(3)]], \ - uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - template [[host_name("g" name "_1")]] [[kernel]] void \ - copy_gg_nd1( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t& src_stride [[buffer(3)]], \ - constant const int64_t& dst_stride [[buffer(4)]], \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("g" name "_2")]] [[kernel]] void 
\ - copy_gg_nd2( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ - uint2 index [[thread_position_in_grid]]); \ - template [[host_name("g" name "_3")]] [[kernel]] void \ - copy_gg_nd3( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ - uint3 index [[thread_position_in_grid]]); \ - instantiate_copy_g_dim(name, itype, otype, 4) \ - instantiate_copy_g_dim(name, itype, otype, 5) +#define instantiate_copy_g_nd(name, itype, otype) \ + template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t& src_stride [[buffer(3)]], \ + uint index [[thread_position_in_grid]]); \ + template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("gg1_" name )]] [[kernel]] void \ + copy_gg_nd1( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t& src_stride [[buffer(3)]], \ + constant const int64_t& dst_stride [[buffer(4)]], \ + uint index [[thread_position_in_grid]]); \ + template [[host_name("gg2_" name)]] [[kernel]] void \ + copy_gg_nd2( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], 
\ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + uint2 index [[thread_position_in_grid]]); \ + template [[host_name("gg3_" name)]] [[kernel]] void \ + copy_gg_nd3( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + uint3 index [[thread_position_in_grid]]); \ + instantiate_copy_g_dim(name, itype, otype, 4) \ + instantiate_copy_g_dim(name, itype, otype, 5) -#define instantiate_copy_g(name, itype, otype) \ - template [[host_name(name)]] [[kernel]] void copy_g( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int& ndim [[buffer(5)]], \ - uint3 index [[thread_position_in_grid]], \ - uint3 grid_dim [[threads_per_grid]]); \ - template [[host_name("g" name)]] [[kernel]] void copy_gg( \ - device const itype* src [[buffer(0)]], \ - device otype* dst [[buffer(1)]], \ - constant const int* src_shape [[buffer(2)]], \ - constant const int64_t* src_strides [[buffer(3)]], \ - constant const int64_t* dst_strides [[buffer(4)]], \ - constant const int& ndim [[buffer(5)]], \ +#define instantiate_copy_g(name, itype, otype) \ + template [[host_name("g_" name)]] [[kernel]] void copy_g( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int& ndim [[buffer(5)]], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + template [[host_name("gg_" name)]] [[kernel]] void copy_gg( \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides 
[[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + constant const int& ndim [[buffer(5)]], \ uint3 index [[thread_position_in_grid]]); -// clang-format off -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_copy("scopy" #tname, itype, otype, s) \ - instantiate_copy("vcopy" #tname, itype, otype, v) \ - instantiate_copy_g("gcopy" #tname, itype, otype) \ - instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_copy("s_copy" #tname, itype, otype, s) \ + instantiate_copy("v_copy" #tname, itype, otype, v) \ + instantiate_copy_g("copy" #tname, itype, otype) \ + instantiate_copy_g_nd("copy" #tname, itype, otype) -// clang-format off #define instantiate_copy_itype(itname, itype) \ instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##uint8, itype, uint8_t) \ diff --git a/mlx/backend/metal/kernels/defines.h b/mlx/backend/metal/kernels/defines.h index 9e62b7c32..7f7bb49ed 100644 --- a/mlx/backend/metal/kernels/defines.h +++ b/mlx/backend/metal/kernels/defines.h @@ -8,8 +8,6 @@ #define MTL_CONST #endif -static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5; -static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5; static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; static MTL_CONST constexpr int REDUCE_N_READS = 16; static MTL_CONST constexpr int SOFTMAX_N_READS = 4; diff --git a/mlx/backend/metal/kernels/erf.h b/mlx/backend/metal/kernels/erf.h index 0a370d304..da6c2eacd 100644 --- a/mlx/backend/metal/kernels/erf.h +++ b/mlx/backend/metal/kernels/erf.h @@ -1,7 +1,6 @@ // Copyright © 2023 Apple Inc. 
#pragma once - #include /* @@ -67,4 +66,4 @@ float erfinv(float a) { p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 } return a * p; -} \ No newline at end of file +} diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/gather.h new file mode 100644 index 000000000..34f807f3d --- /dev/null +++ b/mlx/backend/metal/kernels/gather.h @@ -0,0 +1,45 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing.h" + +template +METAL_FUNC void gather_impl( + const device T* src [[buffer(0)]], + device T* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant size_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const thread Indices& indices, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto ind_idx = index.x; + auto ind_offset = index.y; + + size_t src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + size_t idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = ind_idx * indices.strides[indices.ndim * i]; + } else { + idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += idx_val * src_strides[ax]; + } + + auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim); + + size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + out[out_idx] = src[src_offset + src_idx]; +} diff --git a/mlx/backend/metal/kernels/gather.metal b/mlx/backend/metal/kernels/gather.metal deleted file mode 100644 index f8e4fbb87..000000000 --- a/mlx/backend/metal/kernels/gather.metal +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. 
- -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/indexing.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -///////////////////////////////////////////////////////////////////// -// Gather kernel -///////////////////////////////////////////////////////////////////// - -template -METAL_FUNC void gather_impl( - const device T* src [[buffer(0)]], - device T* out [[buffer(1)]], - const constant int* src_shape [[buffer(2)]], - const constant size_t* src_strides [[buffer(3)]], - const constant size_t& src_ndim [[buffer(4)]], - const constant int* slice_sizes [[buffer(5)]], - const constant int* axes [[buffer(6)]], - const thread Indices& indices, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto ind_idx = index.x; - auto ind_offset = index.y; - - size_t src_idx = 0; - for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; - if (IDX_NDIM == 0) { - idx_loc = 0; - } else if (IDX_NDIM == 1) { - idx_loc = ind_idx * indices.strides[indices.ndim * i]; - } else { - idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - } - auto ax = axes[i]; - auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; - } - - auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim); - - size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; - out[out_idx] = src[src_offset + src_idx]; -} - -#define make_gather_impl(IDX_ARG, IDX_ARR) \ - template \ - [[kernel]] void gather( \ - const device T* src [[buffer(0)]], \ - device T* out [[buffer(1)]], \ - const constant int* src_shape [[buffer(2)]], \ - const constant size_t* src_strides [[buffer(3)]], \ - const constant size_t& src_ndim [[buffer(4)]], \ - const constant int* slice_sizes [[buffer(5)]], \ - const constant int* axes [[buffer(6)]], \ - const constant int* idx_shapes 
[[buffer(7)]], \ - const constant size_t* idx_strides [[buffer(8)]], \ - const constant int& idx_ndim [[buffer(9)]], \ - IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]) { \ - Indices idxs{ \ - {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \ - \ - return gather_impl( \ - src, \ - out, \ - src_shape, \ - src_strides, \ - src_ndim, \ - slice_sizes, \ - axes, \ - idxs, \ - index, \ - grid_dim); \ - } - -#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n) - -make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4) - make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9) - make_gather(10) - -///////////////////////////////////////////////////////////////////// -// Gather instantiations -///////////////////////////////////////////////////////////////////// - -#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ - template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \ - gather( \ - const device src_t* src [[buffer(0)]], \ - device src_t* out [[buffer(1)]], \ - const constant int* src_shape [[buffer(2)]], \ - const constant size_t* src_strides [[buffer(3)]], \ - const constant size_t& src_ndim [[buffer(4)]], \ - const constant int* slice_sizes [[buffer(5)]], \ - const constant int* axes [[buffer(6)]], \ - const constant int* idx_shapes [[buffer(7)]], \ - const constant size_t* idx_strides [[buffer(8)]], \ - const constant int& idx_ndim [[buffer(9)]], \ - IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); - -// clang-format off -#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ - instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on - -// clang-format off -#define instantiate_gather4(name, src_t, idx_t, nidx) \ - instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ - instantiate_gather5(name, src_t, idx_t, nidx, 1, 
_1) \ - instantiate_gather5(name, src_t, idx_t, nidx, 2, ) - - -// Special for case NIDX=0 -instantiate_gather4("bool_", bool, bool, 0) -instantiate_gather4("uint8", uint8_t, bool, 0) -instantiate_gather4("uint16", uint16_t, bool, 0) -instantiate_gather4("uint32", uint32_t, bool, 0) -instantiate_gather4("uint64", uint64_t, bool, 0) -instantiate_gather4("int8", int8_t, bool, 0) -instantiate_gather4("int16", int16_t, bool, 0) -instantiate_gather4("int32", int32_t, bool, 0) -instantiate_gather4("int64", int64_t, bool, 0) -instantiate_gather4("float16", half, bool, 0) -instantiate_gather4("float32", float, bool, 0) -instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on - -// clang-format off -#define instantiate_gather3(name, src_type, ind_type) \ - instantiate_gather4(name, src_type, ind_type, 1) \ - instantiate_gather4(name, src_type, ind_type, 2) \ - instantiate_gather4(name, src_type, ind_type, 3) \ - instantiate_gather4(name, src_type, ind_type, 4) \ - instantiate_gather4(name, src_type, ind_type, 5) \ - instantiate_gather4(name, src_type, ind_type, 6) \ - instantiate_gather4(name, src_type, ind_type, 7) \ - instantiate_gather4(name, src_type, ind_type, 8) \ - instantiate_gather4(name, src_type, ind_type, 9) \ - instantiate_gather4(name, src_type, ind_type, 10) // clang-format on - -// clang-format off -#define instantiate_gather(name, src_type) \ - instantiate_gather3(#name "bool_", src_type, bool) \ - instantiate_gather3(#name "uint8", src_type, uint8_t) \ - instantiate_gather3(#name "uint16", src_type, uint16_t) \ - instantiate_gather3(#name "uint32", src_type, uint32_t) \ - instantiate_gather3(#name "uint64", src_type, uint64_t) \ - instantiate_gather3(#name "int8", src_type, int8_t) \ - instantiate_gather3(#name "int16", src_type, int16_t) \ - instantiate_gather3(#name "int32", src_type, int32_t) \ - instantiate_gather3(#name "int64", src_type, int64_t) - -instantiate_gather(bool_, bool) -instantiate_gather(uint8, uint8_t) 
-instantiate_gather(uint16, uint16_t) -instantiate_gather(uint32, uint32_t) -instantiate_gather(uint64, uint64_t) -instantiate_gather(int8, int8_t) -instantiate_gather(int16, int16_t) -instantiate_gather(int32, int32_t) -instantiate_gather(int64, int64_t) -instantiate_gather(float16, half) -instantiate_gather(float32, float) -instantiate_gather(bfloat16, bfloat16_t) // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h index c2b37f3ff..9f76e4771 100644 --- a/mlx/backend/metal/kernels/indexing.h +++ b/mlx/backend/metal/kernels/indexing.h @@ -1,13 +1,9 @@ // Copyright © 2023-2024 Apple Inc. +#pragma once + #include -using namespace metal; - -///////////////////////////////////////////////////////////////////// -// Indexing utils -///////////////////////////////////////////////////////////////////// - template struct Indices { const array buffers; @@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { return (idx < 0) ? 
idx + size : idx; } } - -#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]], - -#define IDX_ARG_0(idx_t) -#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21) -#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22) -#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23) -#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24) -#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25) -#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26) -#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27) -#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28) -#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29) -#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30) - -#define IDX_ARR_N(n) idx##n, - -#define IDX_ARR_0() -#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21) -#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22) -#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23) -#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24) -#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25) -#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26) -#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27) -#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28) -#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29) -#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/reduction.h b/mlx/backend/metal/kernels/reduction.h new file mode 100644 index 000000000..279a7afe7 --- /dev/null +++ b/mlx/backend/metal/kernels/reduction.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/atomic.h" +#include "mlx/backend/metal/kernels/reduction/ops.h" diff --git a/mlx/backend/metal/kernels/reduction/ops.h b/mlx/backend/metal/kernels/reduction/ops.h index ea0c495d9..5996c2071 100644 --- a/mlx/backend/metal/kernels/reduction/ops.h +++ b/mlx/backend/metal/kernels/reduction/ops.h @@ -5,9 +5,11 @@ #include #include +#ifndef MLX_METAL_JIT #include "mlx/backend/metal/kernels/atomic.h" #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" +#endif union bool4_or_uint { bool4 b; diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h new file mode 100644 index 000000000..785c6134c --- /dev/null +++ b/mlx/backend/metal/kernels/scatter.h @@ -0,0 +1,66 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing.h" + +template +METAL_FUNC void scatter_1d_index_impl( + const device T* updates [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* out_shape [[buffer(3)]], + const constant size_t* out_strides [[buffer(4)]], + const constant size_t& upd_size [[buffer(5)]], + const thread array& idx_buffers, + uint2 gid [[thread_position_in_grid]]) { + Op op; + + uint out_idx = 0; + for (int i = 0; i < NIDX; i++) { + auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); + out_idx += idx_val * out_strides[i]; + } + + op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); +} + +template +METAL_FUNC void scatter_impl( + const device T* updates [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* upd_shape [[buffer(3)]], + const constant size_t* upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int* out_shape [[buffer(7)]], + const constant size_t* out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* 
axes [[buffer(10)]], + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + Op op; + auto ind_idx = gid.y; + auto ind_offset = gid.x; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += idx_val * out_strides[ax]; + } + + if (upd_size > 1) { + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + out_idx += out_offset; + } + + auto upd_idx = + elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); + op.atomic_update(out, updates[upd_idx], out_idx); +} diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal deleted file mode 100644 index 89cc6f22d..000000000 --- a/mlx/backend/metal/kernels/scatter.metal +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. 
- -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/indexing.h" -#include "mlx/backend/metal/kernels/reduction/ops.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -///////////////////////////////////////////////////////////////////// -// Scatter kernel -///////////////////////////////////////////////////////////////////// - -template -METAL_FUNC void scatter_1d_index_impl( - const device T* updates [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* out_shape [[buffer(3)]], - const constant size_t* out_strides [[buffer(4)]], - const constant size_t& upd_size [[buffer(5)]], - const thread array& idx_buffers, - uint2 gid [[thread_position_in_grid]]) { - Op op; - - uint out_idx = 0; - for (int i = 0; i < NIDX; i++) { - auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); - out_idx += idx_val * out_strides[i]; - } - - op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x); -} - -#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \ - template \ - [[kernel]] void scatter_1d_index( \ - const device T* updates [[buffer(1)]], \ - device mlx_atomic* out [[buffer(2)]], \ - const constant int* out_shape [[buffer(3)]], \ - const constant size_t* out_strides [[buffer(4)]], \ - const constant size_t& upd_size [[buffer(5)]], \ - IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \ - const array idx_buffers = {IDX_ARR()}; \ - \ - return scatter_1d_index_impl( \ - updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \ - } - -template -METAL_FUNC void scatter_impl( - const device T* updates [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* upd_shape [[buffer(3)]], - const constant size_t* upd_strides [[buffer(4)]], - const constant size_t& upd_ndim [[buffer(5)]], - const constant size_t& upd_size [[buffer(6)]], - const constant int* out_shape [[buffer(7)]], - const constant size_t* out_strides [[buffer(8)]], - 
const constant size_t& out_ndim [[buffer(9)]], - const constant int* axes [[buffer(10)]], - const thread Indices& indices, - uint2 gid [[thread_position_in_grid]]) { - Op op; - auto ind_idx = gid.y; - auto ind_offset = gid.x; - - size_t out_idx = 0; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; - } - - if (upd_size > 1) { - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - out_idx += out_offset; - } - - auto upd_idx = - elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx); -} - -#define make_scatter_impl(IDX_ARG, IDX_ARR) \ - template \ - [[kernel]] void scatter( \ - const device T* updates [[buffer(1)]], \ - device mlx_atomic* out [[buffer(2)]], \ - const constant int* upd_shape [[buffer(3)]], \ - const constant size_t* upd_strides [[buffer(4)]], \ - const constant size_t& upd_ndim [[buffer(5)]], \ - const constant size_t& upd_size [[buffer(6)]], \ - const constant int* out_shape [[buffer(7)]], \ - const constant size_t* out_strides [[buffer(8)]], \ - const constant size_t& out_ndim [[buffer(9)]], \ - const constant int* axes [[buffer(10)]], \ - const constant int* idx_shapes [[buffer(11)]], \ - const constant size_t* idx_strides [[buffer(12)]], \ - const constant int& idx_ndim [[buffer(13)]], \ - IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \ - Indices idxs{ \ - {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \ - \ - return scatter_impl( \ - updates, \ - out, \ - upd_shape, \ - upd_strides, \ - upd_ndim, \ - upd_size, \ - out_shape, \ - out_strides, \ - out_ndim, \ - axes, \ - idxs, \ - gid); \ - } - -#define make_scatter(n) \ - make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \ - 
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n) - -make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4) - make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8) - make_scatter(9) make_scatter(10) - -///////////////////////////////////////////////////////////////////// -// Scatter instantiations -///////////////////////////////////////////////////////////////////// - -#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ - template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \ - scatter( \ - const device src_t* updates [[buffer(1)]], \ - device mlx_atomic* out [[buffer(2)]], \ - const constant int* upd_shape [[buffer(3)]], \ - const constant size_t* upd_strides [[buffer(4)]], \ - const constant size_t& upd_ndim [[buffer(5)]], \ - const constant size_t& upd_size [[buffer(6)]], \ - const constant int* out_shape [[buffer(7)]], \ - const constant size_t* out_strides [[buffer(8)]], \ - const constant size_t& out_ndim [[buffer(9)]], \ - const constant int* axes [[buffer(10)]], \ - const constant int* idx_shapes [[buffer(11)]], \ - const constant size_t* idx_strides [[buffer(12)]], \ - const constant int& idx_ndim [[buffer(13)]], \ - IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]); - -#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ - template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \ - scatter_1d_index( \ - const device src_t* updates [[buffer(1)]], \ - device mlx_atomic* out [[buffer(2)]], \ - const constant int* out_shape [[buffer(3)]], \ - const constant size_t* out_strides [[buffer(4)]], \ - const constant size_t& upd_size [[buffer(5)]], \ - IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]); - -// clang-format off -#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ - instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \ - instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on - -// 
clang-format off -// Special case NINDEX=0 -#define instantiate_scatter_nd0(name, type) \ - instantiate_scatter4(#name "none", type, bool, None, 0) \ - instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ - instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ - instantiate_scatter4(#name "_max", type, bool, Max, 0) \ - instantiate_scatter4(#name "_min", type, bool, Min, 0) // clang-format on - -// clang-format off -#define instantiate_scatter3(name, type, ind_type, op_type) \ - instantiate_scatter4(name, type, ind_type, op_type, 1) \ - instantiate_scatter4(name, type, ind_type, op_type, 2) \ - instantiate_scatter4(name, type, ind_type, op_type, 3) \ - instantiate_scatter4(name, type, ind_type, op_type, 4) \ - instantiate_scatter4(name, type, ind_type, op_type, 5) \ - instantiate_scatter4(name, type, ind_type, op_type, 6) \ - instantiate_scatter4(name, type, ind_type, op_type, 7) \ - instantiate_scatter4(name, type, ind_type, op_type, 8) \ - instantiate_scatter4(name, type, ind_type, op_type, 9) \ - instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on - -// clang-format off -#define instantiate_scatter2(name, type, ind_type) \ - instantiate_scatter3(name "_none", type, ind_type, None) \ - instantiate_scatter3(name "_sum", type, ind_type, Sum) \ - instantiate_scatter3(name "_prod", type, ind_type, Prod) \ - instantiate_scatter3(name "_max", type, ind_type, Max) \ - instantiate_scatter3(name "_min", type, ind_type, Min) // clang-format on - -// clang-format off -#define instantiate_scatter(name, type) \ - instantiate_scatter2(#name "bool_", type, bool) \ - instantiate_scatter2(#name "uint8", type, uint8_t) \ - instantiate_scatter2(#name "uint16", type, uint16_t) \ - instantiate_scatter2(#name "uint32", type, uint32_t) \ - instantiate_scatter2(#name "uint64", type, uint64_t) \ - instantiate_scatter2(#name "int8", type, int8_t) \ - instantiate_scatter2(#name "int16", type, int16_t) \ - instantiate_scatter2(#name "int32", type, int32_t) \ 
- instantiate_scatter2(#name "int64", type, int64_t) // clang-format on - - // clang-format off -// TODO uint64 and int64 unsupported -instantiate_scatter_nd0(bool_, bool) -instantiate_scatter_nd0(uint8, uint8_t) -instantiate_scatter_nd0(uint16, uint16_t) -instantiate_scatter_nd0(uint32, uint32_t) -instantiate_scatter_nd0(int8, int8_t) -instantiate_scatter_nd0(int16, int16_t) -instantiate_scatter_nd0(int32, int32_t) -instantiate_scatter_nd0(float16, half) -instantiate_scatter_nd0(float32, float) -instantiate_scatter_nd0(bfloat16, bfloat16_t) - -instantiate_scatter(bool_, bool) -instantiate_scatter(uint8, uint8_t) -instantiate_scatter(uint16, uint16_t) -instantiate_scatter(uint32, uint32_t) -instantiate_scatter(int8, int8_t) -instantiate_scatter(int16, int16_t) -instantiate_scatter(int32, int32_t) -instantiate_scatter(float16, half) -instantiate_scatter(float32, float) -instantiate_scatter(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index e0235d9dd..312a73207 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -1,10 +1,102 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2024 Apple Inc. -#pragma once +template +[[kernel]] void ternary_v( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + uint index [[thread_position_in_grid]]) { + d[index] = Op()(a[index], b[index], c[index]); +} -struct Select { - template - T operator()(bool condition, T x, T y) { - return condition ? 
x : y; - } -}; +template +[[kernel]] void ternary_g_nd1( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t& a_strides, + constant const size_t& b_strides, + constant const size_t& c_strides, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); + d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + constant const size_t c_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd3( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + constant const size_t c_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int shape[DIM], + constant const size_t a_strides[DIM], + constant const size_t b_strides[DIM], + constant const size_t c_strides[DIM], + uint3 index 
[[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = + elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides); + size_t out_idx = + index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); +} + +template +[[kernel]] void ternary_g( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = + elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim); + size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); + d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); +} diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index 11fd87a91..b9392eb56 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -1,115 +1,16 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2024 Apple Inc. 
#include #include +// clang-format off #include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/ternary.h" #include "mlx/backend/metal/kernels/utils.h" - -template -[[kernel]] void ternary_op_v( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); -} - -template -[[kernel]] void ternary_op_g_nd1( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const size_t& a_strides, - constant const size_t& b_strides, - constant const size_t& c_strides, - uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); - d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); -} - -template -[[kernel]] void ternary_op_g_nd2( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const size_t a_strides[2], - constant const size_t b_strides[2], - constant const size_t c_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - auto c_idx = elem_to_loc_2(index, c_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * index.y; - d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); -} - -template -[[kernel]] void ternary_op_g_nd3( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - constant const size_t c_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - auto c_idx = elem_to_loc_3(index, c_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + 
(size_t)grid_dim.y * index.z); - d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); -} - -template -[[kernel]] void ternary_op_g_nd( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int shape[DIM], - constant const size_t a_strides[DIM], - constant const size_t b_strides[DIM], - constant const size_t c_strides[DIM], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); -} - -template -[[kernel]] void ternary_op_g( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim); - size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); - d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); -} +#include "mlx/backend/metal/kernels/ternary_ops.h" +#include "mlx/backend/metal/kernels/ternary.h" #define instantiate_ternary_v(name, type, op) \ - template [[host_name(name)]] [[kernel]] void ternary_op_v( \ + template [[host_name("v_" name)]] [[kernel]] void ternary_v( \ device const bool* a, \ device const type* b, \ device const type* c, \ @@ -117,7 +18,7 @@ template uint index [[thread_position_in_grid]]); #define instantiate_ternary_g(name, type, op) \ - template [[host_name(name)]] [[kernel]] void ternary_op_g( \ + template [[host_name("g_" name)]] [[kernel]] void ternary_g( \ device const bool* a, \ device const type* b, \ device const type* c, \ @@ -131,8 +32,8 @@ template uint3 grid_dim [[threads_per_grid]]); 
#define instantiate_ternary_g_dim(name, type, op, dims) \ - template [[host_name(name "_" #dims)]] [[kernel]] void \ - ternary_op_g_nd( \ + template [[host_name("g" #dims "_" name )]] [[kernel]] void \ + ternary_g_nd( \ device const bool* a, \ device const type* b, \ device const type* c, \ @@ -145,8 +46,8 @@ template uint3 grid_dim [[threads_per_grid]]); #define instantiate_ternary_g_nd(name, type, op) \ - template [[host_name(name "_1")]] [[kernel]] void \ - ternary_op_g_nd1( \ + template [[host_name("g1_" name)]] [[kernel]] void \ + ternary_g_nd1( \ device const bool* a, \ device const type* b, \ device const type* c, \ @@ -155,8 +56,8 @@ template constant const size_t& b_strides, \ constant const size_t& c_strides, \ uint index [[thread_position_in_grid]]); \ - template [[host_name(name "_2")]] [[kernel]] void \ - ternary_op_g_nd2( \ + template [[host_name("g2_" name)]] [[kernel]] void \ + ternary_g_nd2( \ device const bool* a, \ device const type* b, \ device const type* c, \ @@ -166,8 +67,8 @@ template constant const size_t c_strides[2], \ uint2 index [[thread_position_in_grid]], \ uint2 grid_dim [[threads_per_grid]]); \ - template [[host_name(name "_3")]] [[kernel]] void \ - ternary_op_g_nd3( \ + template [[host_name("g3_" name)]] [[kernel]] void \ + ternary_g_nd3( \ device const bool* a, \ device const type* b, \ device const type* c, \ @@ -178,15 +79,13 @@ template uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); \ instantiate_ternary_g_dim(name, type, op, 4) \ - instantiate_ternary_g_dim(name, type, op, 5) + instantiate_ternary_g_dim(name, type, op, 5) -// clang-format off #define instantiate_ternary_all(name, tname, type, op) \ - instantiate_ternary_v("v" #name #tname, type, op) \ - instantiate_ternary_g("g" #name #tname, type, op) \ - instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on + instantiate_ternary_v(#name #tname, type, op) \ + instantiate_ternary_g(#name #tname, type, op) \ + 
instantiate_ternary_g_nd(#name #tname, type, op) -// clang-format off #define instantiate_ternary_types(name, op) \ instantiate_ternary_all(name, bool_, bool, op) \ instantiate_ternary_all(name, uint8, uint8_t, op) \ @@ -202,4 +101,4 @@ template instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \ instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on -instantiate_ternary_types(select, Select) \ No newline at end of file +instantiate_ternary_types(select, Select) diff --git a/mlx/backend/metal/kernels/ternary_ops.h b/mlx/backend/metal/kernels/ternary_ops.h new file mode 100644 index 000000000..e0235d9dd --- /dev/null +++ b/mlx/backend/metal/kernels/ternary_ops.h @@ -0,0 +1,10 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +struct Select { + template + T operator()(bool condition, T x, T y) { + return condition ? x : y; + } +}; diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 3752f6061..80b17121c 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,394 +1,21 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2024 Apple Inc. 
-#pragma once - -#include -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/erf.h" -#include "mlx/backend/metal/kernels/expm1f.h" -#include "mlx/backend/metal/kernels/utils.h" - -namespace { -constant float inf = metal::numeric_limits::infinity(); +template +[[kernel]] void unary_v( + device const T* in, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = Op()(in[index]); } -struct Abs { - template - T operator()(T x) { - return metal::abs(x); - }; - template <> - uint8_t operator()(uint8_t x) { - return x; - }; - template <> - uint16_t operator()(uint16_t x) { - return x; - }; - template <> - uint32_t operator()(uint32_t x) { - return x; - }; - template <> - uint64_t operator()(uint64_t x) { - return x; - }; - template <> - bool operator()(bool x) { - return x; - }; - template <> - complex64_t operator()(complex64_t x) { - return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; - }; -}; - -struct ArcCos { - template - T operator()(T x) { - return metal::precise::acos(x); - }; -}; - -struct ArcCosh { - template - T operator()(T x) { - return metal::precise::acosh(x); - }; -}; - -struct ArcSin { - template - T operator()(T x) { - return metal::precise::asin(x); - }; -}; - -struct ArcSinh { - template - T operator()(T x) { - return metal::precise::asinh(x); - }; -}; - -struct ArcTan { - template - T operator()(T x) { - return metal::precise::atan(x); - }; -}; - -struct ArcTanh { - template - T operator()(T x) { - return metal::precise::atanh(x); - }; -}; - -struct Ceil { - template - T operator()(T x) { - return metal::ceil(x); - }; - template <> - int8_t operator()(int8_t x) { - return x; - }; - template <> - int16_t operator()(int16_t x) { - return x; - }; - template <> - int32_t operator()(int32_t x) { - return x; - }; - template <> - int64_t operator()(int64_t x) { - return x; - }; - template <> - uint8_t operator()(uint8_t x) { - return x; - }; - template <> - uint16_t 
operator()(uint16_t x) { - return x; - }; - template <> - uint32_t operator()(uint32_t x) { - return x; - }; - template <> - uint64_t operator()(uint64_t x) { - return x; - }; - template <> - bool operator()(bool x) { - return x; - }; -}; - -struct Cos { - template - T operator()(T x) { - return metal::precise::cos(x); - }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cos(x.real) * metal::precise::cosh(x.imag), - -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; - }; -}; - -struct Cosh { - template - T operator()(T x) { - return metal::precise::cosh(x); - }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::cosh(x.real) * metal::precise::cos(x.imag), - metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; - }; -}; - -struct Conjugate { - complex64_t operator()(complex64_t x) { - return complex64_t{x.real, -x.imag}; - } -}; - -struct Erf { - template - T operator()(T x) { - return static_cast(erf(static_cast(x))); - }; -}; - -struct ErfInv { - template - T operator()(T x) { - return static_cast(erfinv(static_cast(x))); - }; -}; - -struct Exp { - template - T operator()(T x) { - return metal::precise::exp(x); - }; - template <> - complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; - } -}; - -struct Expm1 { - template - T operator()(T x) { - return static_cast(expm1f(static_cast(x))); - }; -}; - -struct Floor { - template - T operator()(T x) { - return metal::floor(x); - }; - template <> - int8_t operator()(int8_t x) { - return x; - }; - template <> - int16_t operator()(int16_t x) { - return x; - }; - template <> - int32_t operator()(int32_t x) { - return x; - }; - template <> - int64_t operator()(int64_t x) { - return x; - }; - template <> - uint8_t operator()(uint8_t x) { - return x; - }; - template <> - uint16_t operator()(uint16_t x) { - return x; - }; - template 
<> - uint32_t operator()(uint32_t x) { - return x; - }; - template <> - uint64_t operator()(uint64_t x) { - return x; - }; - template <> - bool operator()(bool x) { - return x; - }; -}; - -struct Log { - template - T operator()(T x) { - return metal::precise::log(x); - }; -}; - -struct Log2 { - template - T operator()(T x) { - return metal::precise::log2(x); - }; -}; - -struct Log10 { - template - T operator()(T x) { - return metal::precise::log10(x); - }; -}; - -struct Log1p { - template - T operator()(T x) { - return log1p(x); - }; -}; - -struct LogicalNot { - template - T operator()(T x) { - return !x; - }; -}; - -struct Negative { - template - T operator()(T x) { - return -x; - }; -}; - -struct Round { - template - T operator()(T x) { - return metal::rint(x); - }; - template <> - complex64_t operator()(complex64_t x) { - return {metal::rint(x.real), metal::rint(x.imag)}; - }; -}; - -struct Sigmoid { - template - T operator()(T x) { - auto y = 1 / (1 + metal::exp(-metal::abs(x))); - return (x < 0) ? 
1 - y : y; - } -}; - -struct Sign { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - }; - template <> - uint32_t operator()(uint32_t x) { - return x != 0; - }; -}; - -struct Sin { - template - T operator()(T x) { - return metal::precise::sin(x); - }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sin(x.real) * metal::precise::cosh(x.imag), - metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; - }; -}; - -struct Sinh { - template - T operator()(T x) { - return metal::precise::sinh(x); - }; - - template <> - complex64_t operator()(complex64_t x) { - return { - metal::precise::sinh(x.real) * metal::precise::cos(x.imag), - metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; - }; -}; - -struct Square { - template - T operator()(T x) { - return x * x; - }; -}; - -struct Sqrt { - template - T operator()(T x) { - return metal::precise::sqrt(x); - }; -}; - -struct Rsqrt { - template - T operator()(T x) { - return metal::precise::rsqrt(x); - }; -}; - -struct Tan { - template - T operator()(T x) { - return metal::precise::tan(x); - }; - - template <> - complex64_t operator()(complex64_t x) { - float tan_a = metal::precise::tan(x.real); - float tanh_b = metal::precise::tanh(x.imag); - float t1 = tan_a * tanh_b; - float denom = 1. + t1 * t1; - return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; - }; -}; - -struct Tanh { - template - T operator()(T x) { - return metal::precise::tanh(x); - }; - - template <> - complex64_t operator()(complex64_t x) { - float tanh_a = metal::precise::tanh(x.real); - float tan_b = metal::precise::tan(x.imag); - float t1 = tanh_a * tan_b; - float denom = 1. 
+ t1 * t1; - return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; - }; -}; +template +[[kernel]] void unary_g( + device const T* in, + device T* out, + device const int* in_shape, + device const size_t* in_strides, + device const int& ndim, + uint index [[thread_position_in_grid]]) { + auto idx = elem_to_loc(index, in_shape, in_strides, ndim); + out[index] = Op()(in[idx]); +} diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index c1864ff14..002a5a24f 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -1,54 +1,34 @@ -// Copyright © 2023-2024 Apple Inc. +// Copyright © 2024 Apple Inc. +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary.h" -template -[[kernel]] void unary_op_v( - device const T* in, - device T* out, - uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); -} - -template -[[kernel]] void unary_op_g( - device const T* in, - device T* out, - device const int* in_shape, - device const size_t* in_strides, - device const int& ndim, - uint index [[thread_position_in_grid]]) { - auto idx = elem_to_loc(index, in_shape, in_strides, ndim); - out[index] = Op()(in[idx]); -} - -#define instantiate_unary_v(name, type, op) \ - template [[host_name(name)]] [[kernel]] void unary_op_v( \ - device const type* in, \ - device type* out, \ +#define instantiate_unary_v(name, type, op) \ + template [[host_name(name)]] [[kernel]] void unary_v( \ + device const type* in, \ + device type* out, \ uint index [[thread_position_in_grid]]); -#define instantiate_unary_g(name, type, op) \ - template [[host_name(name)]] [[kernel]] void unary_op_g( \ - device const type* in, \ - device type* out, \ - device const int* in_shape, \ - device const size_t* in_strides, \ - device const int& ndim, \ +#define instantiate_unary_g(name, type, op) \ + template 
[[host_name(name)]] [[kernel]] void unary_g( \ + device const type* in, \ + device type* out, \ + device const int* in_shape, \ + device const size_t* in_strides, \ + device const int& ndim, \ uint index [[thread_position_in_grid]]); -// clang-format off #define instantiate_unary_all(name, tname, type, op) \ instantiate_unary_v("v" #name #tname, type, op) \ - instantiate_unary_g("g" #name #tname, type, op) // clang-format on + instantiate_unary_g("g" #name #tname, type, op) -// clang-format off #define instantiate_unary_float(name, op) \ instantiate_unary_all(name, float16, half, op) \ instantiate_unary_all(name, float32, float, op) \ - instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on + instantiate_unary_all(name, bfloat16, bfloat16_t, op) -// clang-format off #define instantiate_unary_types(name, op) \ instantiate_unary_all(name, bool_, bool, op) \ instantiate_unary_all(name, uint8, uint8_t, op) \ @@ -59,9 +39,8 @@ template instantiate_unary_all(name, int16, int16_t, op) \ instantiate_unary_all(name, int32, int32_t, op) \ instantiate_unary_all(name, int64, int64_t, op) \ - instantiate_unary_float(name, op) // clang-format on + instantiate_unary_float(name, op) -// clang-format off instantiate_unary_types(abs, Abs) instantiate_unary_float(arccos, ArcCos) instantiate_unary_float(arccosh, ArcCosh) diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h new file mode 100644 index 000000000..fb4c6dbe5 --- /dev/null +++ b/mlx/backend/metal/kernels/unary_ops.h @@ -0,0 +1,392 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/expm1f.h" + +namespace { +constant float inf = metal::numeric_limits::infinity(); +} + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct 
Cos { + template + T operator()(T x) { + return metal::precise::cos(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Conjugate { + complex64_t operator()(complex64_t x) { + return complex64_t{x.real, -x.imag}; + } +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + template <> + complex64_t operator()(complex64_t x) { + auto m = metal::precise::exp(x.real); + return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + } +}; + +struct Expm1 { + template + T operator()(T x) { + return static_cast(expm1f(static_cast(x))); + }; +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + template <> + int8_t operator()(int8_t x) { + return x; + }; + template <> + int16_t operator()(int16_t x) { + return x; + }; + template <> + int32_t operator()(int32_t x) { + return x; + }; + template <> + int64_t operator()(int64_t x) { + return x; + }; + template <> + uint8_t operator()(uint8_t x) { + return x; + }; + template <> + uint16_t operator()(uint16_t x) { + return x; + }; + template <> + uint32_t operator()(uint32_t x) { + return x; + }; + template <> + uint64_t operator()(uint64_t x) { + return x; + }; + template <> + bool operator()(bool x) { + return x; + }; +}; + +struct Log { + template + T operator()(T x) { + return 
metal::precise::log(x); + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + template <> + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 1 - y : y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + template <> + uint32_t operator()(uint32_t x) { + return x != 0; + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + 
float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. + t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + template <> + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 641df11f0..8241b3006 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -6,6 +6,8 @@ #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/complex.h" +typedef half float16_t; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh index dedd38a64..e397042e3 100644 --- a/mlx/backend/metal/make_compiled_preamble.sh +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -5,24 +5,25 @@ # # Copyright © 2023-24 Apple Inc. - -OUTPUT_FILE=$1 +OUTPUT_DIR=$1 CC=$2 -SRCDIR=$3 -CFLAGS=$4 +SRC_DIR=$3 +SRC_NAME=$4 +CFLAGS=$5 +INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h +OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp -CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null) +mkdir -p $OUTPUT_DIR + +CONTENT=$($CC -I $SRC_DIR -DMLX_METAL_JIT -E -P $INPUT_FILE $CFLAGS 2>/dev/null) cat << EOF > "$OUTPUT_FILE" -// Copyright © 2023-24 Apple Inc. 
- namespace mlx::core::metal { -const char* get_kernel_preamble() { +const char* $SRC_NAME() { return R"preamble( $CONTENT )preamble"; - } } // namespace mlx::core::metal diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp new file mode 100644 index 000000000..b8aede7e9 --- /dev/null +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -0,0 +1,46 @@ + +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +MTL::ComputePipelineState* get_unary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_binary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_binary_two_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_ternary_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& out) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + return d.get_kernel(kernel_name); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index a2c3df651..e1e7996c9 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -4,366 +4,14 @@ #include #include -#include "mlx/backend/common/binary.h" -#include "mlx/backend/common/ternary.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" namespace mlx::core { -namespace { - -constexpr int 
METAL_MAX_INDEX_ARRAYS = 10; - -void binary_op( - const std::vector& inputs, - std::vector& outputs, - const std::string op) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, outputs[0], bopt, true); - set_binary_op_output_data(a, b, outputs[1], bopt, true); - - auto& out = outputs[0]; - if (out.size() == 0) { - return; - } - - // Try to collapse contiguous dims - auto [shape, strides] = collapse_contiguous_dims(a, b, out); - auto& strides_a = strides[0]; - auto& strides_b = strides[1]; - auto& strides_out = strides[2]; - - std::ostringstream kname; - switch (bopt) { - case BinaryOpType::ScalarScalar: - kname << "ss"; - break; - case BinaryOpType::ScalarVector: - kname << "sv"; - break; - case BinaryOpType::VectorScalar: - kname << "vs"; - break; - case BinaryOpType::VectorVector: - kname << "vv"; - break; - case BinaryOpType::General: - kname << "g"; - break; - } - kname << op << type_to_name(a); - if (bopt == BinaryOpType::General && - shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { - kname << "_" << shape.size(); - } - - auto& s = out.primitive().stream(); - auto& d = metal::device(s.device); - auto kernel = d.get_kernel(kname.str()); - auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); - // - If a is donated it goes to the first output - // - If b is donated it goes to the first output if a was not donated - // otherwise it goes to the second output - bool donate_a = a.data_shared_ptr() == nullptr; - bool donate_b = b.data_shared_ptr() == nullptr; - compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0); - compute_encoder.set_input_array( - donate_b ? (donate_a ? 
outputs[1] : outputs[0]) : b, 1); - compute_encoder.set_output_array(outputs[0], 2); - compute_encoder.set_output_array(outputs[1], 3); - - if (bopt == BinaryOpType::General) { - auto ndim = shape.size(); - if (ndim > 3) { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); - } else { - // The shape is implicit in the grid for <= 3D - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); - } - - if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { - compute_encoder->setBytes(&ndim, sizeof(int), 7); - } - - // Launch up to 3D grid of threads - size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; - size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; - size_t rest = out.size() / (dim0 * dim1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size != 1024) { - throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); - } - auto group_dims = get_block_dims(dim0, dim1, rest); - MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } else { - // Launch a 1D grid of threads - size_t nthreads = out.data_size(); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } -} - -void binary_op( - const std::vector& inputs, - array& out, - const std::string op) { - assert(inputs.size() == 2); - auto& a = inputs[0]; - auto& b = inputs[1]; - auto bopt = get_binary_op_type(a, b); - set_binary_op_output_data(a, b, out, bopt, true); - if (out.size() == 0) { - return; - } - - // Try to 
collapse contiguous dims - auto [shape, strides] = collapse_contiguous_dims(a, b, out); - auto& strides_a = strides[0]; - auto& strides_b = strides[1]; - auto& strides_out = strides[2]; - - std::ostringstream kname; - switch (bopt) { - case BinaryOpType::ScalarScalar: - kname << "ss"; - break; - case BinaryOpType::ScalarVector: - kname << "sv"; - break; - case BinaryOpType::VectorScalar: - kname << "vs"; - break; - case BinaryOpType::VectorVector: - kname << "vv"; - break; - case BinaryOpType::General: - kname << "g"; - break; - } - kname << op << type_to_name(a); - if (bopt == BinaryOpType::General && - shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { - kname << "_" << shape.size(); - } - - auto& s = out.primitive().stream(); - auto& d = metal::device(s.device); - auto kernel = d.get_kernel(kname.str()); - auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); - bool donate_a = a.data_shared_ptr() == nullptr; - bool donate_b = b.data_shared_ptr() == nullptr; - compute_encoder.set_input_array(donate_a ? out : a, 0); - compute_encoder.set_input_array(donate_b ? out : b, 1); - compute_encoder.set_output_array(out, 2); - - if (bopt == BinaryOpType::General) { - auto ndim = shape.size(); - if (ndim > 3) { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3); - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); - } else { - // The shape is implicit in the grid for <= 3D - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4); - } - - if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { - compute_encoder->setBytes(&ndim, sizeof(int), 6); - } - - // Launch up to 3D grid of threads - size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; - size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1; - size_t rest = out.size() / (dim0 * dim1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size != 1024) { - throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); - } - auto group_dims = get_block_dims(dim0, dim1, rest); - MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } else { - // Launch a 1D grid of threads - size_t nthreads = - bopt == BinaryOpType::General ? out.size() : out.data_size(); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } -} - -void ternary_op( - const std::vector& inputs, - array& out, - const std::string op) { - assert(inputs.size() == 3); - auto& a = inputs[0]; - auto& b = inputs[1]; - auto& c = inputs[2]; - TernaryOpType topt = get_ternary_op_type(a, b, c); - set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */); - - if (out.size() == 0) { - return; - } - - // Try to collapse contiguous dims - auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); - auto& strides_a = strides[0]; - auto& strides_b = strides[1]; - auto& strides_c = strides[2]; - auto& strides_out = strides[3]; - - std::ostringstream kname; - if (topt == TernaryOpType::General) { - kname << "g"; - kname << op << type_to_name(b); - if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) { - kname << "_" << shape.size(); - } - } else { - kname << "v"; - kname << op << type_to_name(b); - } - - auto& s = out.primitive().stream(); - auto& d = metal::device(s.device); - auto kernel = d.get_kernel(kname.str()); - auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); - 
compute_encoder.set_input_array(a, 0); - compute_encoder.set_input_array(b, 1); - compute_encoder.set_input_array(c, 2); - compute_encoder.set_output_array(out, 3); - - if (topt == TernaryOpType::General) { - auto ndim = shape.size(); - if (ndim > 3) { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); - compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7); - - if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { - compute_encoder->setBytes(&ndim, sizeof(int), 8); - } - } else { - // The shape is implicit in the grid for <= 3D - compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); - compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6); - } - - // Launch up to 3D grid of threads - size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; - size_t dim1 = ndim > 1 ? 
shape[ndim - 2] : 1; - size_t rest = out.size() / (dim0 * dim1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size != 1024) { - throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); - } - MTL::Size group_dims = get_block_dims(dim0, dim1, rest); - MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } else { - // Launch a 1D grid of threads - size_t nthreads = out.data_size(); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - compute_encoder.dispatchThreads(grid_dims, group_dims); - } -} - -void unary_op( - const std::vector& inputs, - array& out, - const std::string op) { - auto& in = inputs[0]; - bool contig = in.flags().contiguous; - if (contig) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.move_shared_buffer(in); - } else { - out.set_data( - allocator::malloc_or_wait(in.data_size() * out.itemsize()), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - } - if (in.size() == 0) { - return; - } - - auto& s = out.primitive().stream(); - auto& d = metal::device(s.device); - std::string tname = type_to_name(in); - std::string opt_name = contig ? "v" : "g"; - auto kernel = d.get_kernel(opt_name + op + tname); - - size_t nthreads = contig ? 
in.data_size() : in.size(); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > nthreads) { - thread_group_size = nthreads; - } - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - - auto& compute_encoder = d.get_command_encoder(s.index); - compute_encoder->setComputePipelineState(kernel); - compute_encoder.set_input_array( - in.data_shared_ptr() == nullptr ? out : in, 0); - compute_encoder.set_output_array(out, 1); - if (!contig) { - compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2); - compute_encoder->setBytes( - in.strides().data(), in.ndim() * sizeof(size_t), 3); - int ndim = in.ndim(); - compute_encoder->setBytes(&ndim, sizeof(int), 4); - } - compute_encoder.dispatchThreads(grid_dims, group_dims); -} - -} // namespace - -void Abs::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "abs"); -} - -void Add::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "add"); -} - template void arange_set_scalars(T start, T next, CommandEncoder& enc) { enc->setBytes(&start, sizeof(T), 0); @@ -431,34 +79,6 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreads(grid_dims, group_dims); } -void ArcCos::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arccos"); -} - -void ArcCosh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arccosh"); -} - -void ArcSin::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arcsin"); -} - -void ArcSinh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arcsinh"); -} - -void ArcTan::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "arctan"); -} - -void ArcTan2::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "arctan2"); -} - -void ArcTanh::eval_gpu(const std::vector& inputs, 
array& out) { - unary_op(inputs, out, "arctanh"); -} - void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; @@ -537,26 +157,6 @@ void AsStrided::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { - switch (op_) { - case BitwiseBinary::And: - binary_op(inputs, out, "bitwise_and"); - break; - case BitwiseBinary::Or: - binary_op(inputs, out, "bitwise_or"); - break; - case BitwiseBinary::Xor: - binary_op(inputs, out, "bitwise_xor"); - break; - case BitwiseBinary::LeftShift: - binary_op(inputs, out, "left_shift"); - break; - case BitwiseBinary::RightShift: - binary_op(inputs, out, "right_shift"); - break; - } -} - void Broadcast::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } @@ -588,29 +188,10 @@ void Concatenate::eval_gpu(const std::vector& inputs, array& out) { } } -void Conjugate::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (out.dtype() == complex64) { - unary_op(inputs, out, "conj"); - } else { - throw std::invalid_argument( - "[conjugate] conjugate must be called on complex input."); - } -} - void Copy::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void Cos::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "cos"); -} - -void Cosh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "cosh"); -} - void CustomVJP::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -623,40 +204,6 @@ void Depends::eval_gpu( eval(inputs, outputs); } -void Divide::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "div"); -} - -void DivMod::eval_gpu( - const std::vector& inputs, - std::vector& outputs) { - binary_op(inputs, outputs, "divmod"); -} - -void Remainder::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, 
out, "rem"); -} - -void Equal::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, equal_nan_ ? "naneq" : "eq"); -} - -void Erf::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "erf"); -} - -void ErfInv::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "erfinv"); -} - -void Exp::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "exp"); -} - -void Expm1::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "expm1"); -} - void Full::eval_gpu(const std::vector& inputs, array& out) { auto in = inputs[0]; CopyType ctype; @@ -670,102 +217,14 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { copy_gpu(in, out, ctype); } -void Greater::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "ge"); -} - -void GreaterEqual::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "geq"); -} - -void Less::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "le"); -} - -void LessEqual::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "leq"); -} - void Load::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void Log::eval_gpu(const std::vector& inputs, array& out) { - switch (base_) { - case Base::e: - unary_op(inputs, out, "log"); - break; - case Base::two: - unary_op(inputs, out, "log2"); - break; - case Base::ten: - unary_op(inputs, out, "log10"); - break; - } -} - -void Log1p::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "log1p"); -} - -void LogicalNot::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "lnot"); -} - -void LogicalAnd::eval_gpu(const std::vector& inputs, array& out) { - binary_op( - inputs, - out, - "land"); // Assume "land" is the operation identifier for logical AND -} - -void LogicalOr::eval_gpu(const std::vector& inputs, array& out) { - binary_op( - inputs, - out, - 
"lor"); // Assume "lor" is the operation identifier for logical OR -} - -void LogAddExp::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "lae"); -} - -void Maximum::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "max"); -} - -void Minimum::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "min"); -} - void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void Floor::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "floor"); -} - -void Ceil::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "ceil"); -} - -void Multiply::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "mul"); -} - -void Select::eval_gpu(const std::vector& inputs, array& out) { - ternary_op(inputs, out, "select"); -} - -void Negative::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "neg"); -} - -void NotEqual::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "neq"); -} - void Pad::eval_gpu(const std::vector& inputs, array& out) { // Inputs must be base input array and scalar val array assert(inputs.size() == 2); @@ -797,10 +256,6 @@ void Pad::eval_gpu(const std::vector& inputs, array& out) { copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); } -void Power::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "pow"); -} - void RandomBits::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); @@ -861,51 +316,12 @@ void Reshape::eval_gpu(const std::vector& inputs, array& out) { } } -void Round::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; - if (issubdtype(in.dtype(), inexact)) { - unary_op(inputs, out, "round"); - } else { - // No-op integer types - out.copy_shared_buffer(in); - } -} - -void Sigmoid::eval_gpu(const std::vector& 
inputs, array& out) { - unary_op(inputs, out, "sigmoid"); -} - -void Sign::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "sign"); -} - -void Sin::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "sin"); -} - -void Sinh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "sinh"); -} - void Split::eval_gpu( const std::vector& inputs, std::vector& outputs) { eval(inputs, outputs); } -void Square::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "square"); -} - -void Sqrt::eval_gpu(const std::vector& inputs, array& out) { - if (recip_) { - unary_op(inputs, out, "rsqrt"); - } else { - unary_op(inputs, out, "sqrt"); - } -} - void Slice::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); if (out.size() == 0) { @@ -980,18 +396,6 @@ void StopGradient::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void Subtract::eval_gpu(const std::vector& inputs, array& out) { - binary_op(inputs, out, "sub"); -} - -void Tan::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "tan"); -} - -void Tanh::eval_gpu(const std::vector& inputs, array& out) { - unary_op(inputs, out, "tanh"); -} - void Transpose::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp new file mode 100644 index 000000000..ed2f4c14a --- /dev/null +++ b/mlx/backend/metal/ternary.cpp @@ -0,0 +1,108 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5; + +void ternary_op( + const std::vector& inputs, + array& out, + const std::string op) { + assert(inputs.size() == 3); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& c = inputs[2]; + TernaryOpType topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */); + + if (out.size() == 0) { + return; + } + + // Try to collapse contiguous dims + auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); + auto& strides_a = strides[0]; + auto& strides_b = strides[1]; + auto& strides_c = strides[2]; + auto& strides_out = strides[3]; + + std::string kernel_name; + { + std::ostringstream kname; + if (topt == TernaryOpType::General) { + kname << "g"; + if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) { + kname << shape.size(); + } + } else { + kname << "v"; + } + kname << "_" << op << type_to_name(b); + kernel_name = kname.str(); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + + auto kernel = get_ternary_kernel(d, kernel_name, out); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(c, 2); + compute_encoder.set_output_array(out, 3); + + if (topt == TernaryOpType::General) { + auto ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4); + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6); + compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7); + + if (ndim > MAX_TERNARY_SPECIALIZED_DIMS) 
{ + compute_encoder->setBytes(&ndim, sizeof(int), 8); + } + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); + compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6); + } + + // Launch up to 3D grid of threads + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t rest = out.size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + MTL::Size group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } else { + // Launch a 1D grid of threads + size_t nthreads = out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder.dispatchThreads(grid_dims, group_dims); + } +} + +void Select::eval_gpu(const std::vector& inputs, array& out) { + ternary_op(inputs, out, "select"); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp new file mode 100644 index 000000000..c540f7c06 --- /dev/null +++ b/mlx/backend/metal/unary.cpp @@ -0,0 +1,206 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void unary_op( + const std::vector& inputs, + array& out, + const std::string op) { + auto& in = inputs[0]; + bool contig = in.flags().contiguous; + if (contig) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.move_shared_buffer(in); + } else { + out.set_data( + allocator::malloc_or_wait(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + if (in.size() == 0) { + return; + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + + std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out); + auto kernel = get_unary_kernel(d, kernel_name, out); + + size_t nthreads = contig ? in.data_size() : in.size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + compute_encoder.set_input_array( + in.data_shared_ptr() == nullptr ? 
out : in, 0); + compute_encoder.set_output_array(out, 1); + if (!contig) { + compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2); + compute_encoder->setBytes( + in.strides().data(), in.ndim() * sizeof(size_t), 3); + int ndim = in.ndim(); + compute_encoder->setBytes(&ndim, sizeof(int), 4); + } + compute_encoder.dispatchThreads(grid_dims, group_dims); +} + +void Abs::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "abs"); +} + +void ArcCos::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arccos"); +} + +void ArcCosh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arccosh"); +} + +void ArcSin::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arcsin"); +} + +void ArcSinh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arcsinh"); +} + +void ArcTan::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arctan"); +} + +void ArcTanh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arctanh"); +} + +void Conjugate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == complex64) { + unary_op(inputs, out, "conj"); + } else { + throw std::invalid_argument( + "[conjugate] conjugate must be called on complex input."); + } +} + +void Cos::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "cos"); +} + +void Cosh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "cosh"); +} + +void Erf::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "erf"); +} + +void ErfInv::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "erfinv"); +} + +void Exp::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "exp"); +} + +void Expm1::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, 
"expm1"); +} + +void Log::eval_gpu(const std::vector& inputs, array& out) { + switch (base_) { + case Base::e: + unary_op(inputs, out, "log"); + break; + case Base::two: + unary_op(inputs, out, "log2"); + break; + case Base::ten: + unary_op(inputs, out, "log10"); + break; + } +} + +void Log1p::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "log1p"); +} + +void LogicalNot::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "lnot"); +} + +void Floor::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "floor"); +} + +void Ceil::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "ceil"); +} + +void Negative::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "neg"); +} + +void Round::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (issubdtype(in.dtype(), inexact)) { + unary_op(inputs, out, "round"); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + +void Sigmoid::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sigmoid"); +} + +void Sign::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sign"); +} + +void Sin::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sin"); +} + +void Sinh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sinh"); +} + +void Square::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "square"); +} + +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + if (recip_) { + unary_op(inputs, out, "rsqrt"); + } else { + unary_op(inputs, out, "sqrt"); + } +} + +void Tan::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "tan"); +} + +void Tanh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "tanh"); +} + +} // namespace mlx::core diff --git 
a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 6c9430f18..e5e733555 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -33,9 +33,9 @@ NO_CPU(AsType) NO_CPU(AsStrided) NO_CPU(BitwiseBinary) NO_CPU(BlockMaskedMM) -NO_CPU(BlockSparseMM) NO_CPU(Broadcast) NO_CPU(Ceil) +NO_CPU(Cholesky) NO_CPU(Concatenate) NO_CPU(Conjugate) NO_CPU(Convolution) @@ -57,6 +57,8 @@ NO_CPU(FFT) NO_CPU(Floor) NO_CPU(Full) NO_CPU(Gather) +NO_CPU(GatherMM) +NO_CPU(GatherQMM) NO_CPU(Greater) NO_CPU(GreaterEqual) NO_CPU(Less) diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt index 38caeff9a..14a39df73 100644 --- a/mlx/io/CMakeLists.txt +++ b/mlx/io/CMakeLists.txt @@ -41,7 +41,7 @@ if (MLX_BUILD_GGUF) gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c ${gguflib_SOURCE_DIR}/gguflib.c) - target_link_libraries(mlx $) + target_link_libraries(mlx PRIVATE $) target_sources( mlx PRIVATE diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4fb5d8754..4f43a72d6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -708,6 +708,14 @@ std::pair, std::vector> Ceil::vmap( return {{ceil(inputs[0], stream())}, axes}; } +std::pair, std::vector> Cholesky::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0] >= 0 ? 0 : -1; + auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + return {{linalg::cholesky(a, upper_, stream())}, {ax}}; +} + std::vector Concatenate::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 5dd99ea60..45b573a02 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -870,7 +870,7 @@ class Equal : public UnaryPrimitive { void print(std::ostream& os) override { if (equal_nan_) { - os << "NanEqual"; + os << "NaNEqual"; } else { os << "Equal"; }