JIT compile option for binary minimization (#1091)

* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
Awni Hannun, 2024-05-22 12:57:13 -07:00 (committed by GitHub)
parent d568c7ee36
commit 226748b3e7
56 changed files with 3153 additions and 2605 deletions


@@ -114,7 +114,13 @@ jobs:
       - run:
           name: Run CPP tests
           command: |
             DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
             DEVICE=cpu ./build/tests/tests
+      - run:
+          name: Build small binary
+          command: |
+            source env/bin/activate
+            cd build/
+            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
+            make -j
   build_release:
     parameters:


@@ -20,6 +20,7 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
+option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
@@ -109,7 +110,7 @@ elseif (MLX_BUILD_METAL)
     $<INSTALL_INTERFACE:include/metal_cpp>
   )
   target_link_libraries(
-    mlx
+    mlx PUBLIC
     ${METAL_LIB}
     ${FOUNDATION_LIB}
     ${QUARTZ_LIB})
@@ -122,7 +123,7 @@ if (MLX_BUILD_CPU)
   if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
-    target_link_libraries(mlx ${ACCELERATE_LIBRARY})
+    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
     add_compile_definitions(ACCELERATE_NEW_LAPACK)
   else()
     message(STATUS "Accelerate or arm neon not found, using default backend.")
@@ -145,7 +146,7 @@ if (MLX_BUILD_CPU)
     message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
     message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-    target_link_libraries(mlx ${LAPACK_LIBRARIES})
+    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
     # List blas after lapack otherwise we may accidentally include an old version
     # of lapack.h from the include dirs of blas.
     find_package(BLAS REQUIRED)
@@ -160,7 +161,7 @@ if (MLX_BUILD_CPU)
     message(STATUS "Blas lib " ${BLAS_LIBRARIES})
     message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-    target_link_libraries(mlx ${BLAS_LIBRARIES})
+    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
   endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
@@ -175,6 +176,14 @@ target_include_directories(
   $<INSTALL_INTERFACE:include>
 )

+FetchContent_Declare(fmt
+  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+  GIT_TAG 10.2.1
+  EXCLUDE_FROM_ALL
+)
+FetchContent_MakeAvailable(fmt)
+target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
+
 if (MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)


@@ -163,6 +163,8 @@ should point to the path to the built metal library.
     - ON
   * - MLX_BUILD_GGUF
     - ON
+  * - MLX_METAL_JIT
+    - OFF

 .. note::
@@ -196,9 +198,18 @@ GGUF, you can do:
     cmake ..
       -DCMAKE_BUILD_TYPE=MinSizeRel \
       -DBUILD_SHARED_LIBS=ON \
-      -DMLX_BUILD_CPU=ON \
+      -DMLX_BUILD_CPU=OFF \
       -DMLX_BUILD_SAFETENSORS=OFF \
-      -DMLX_BUILD_GGUF=OFF
+      -DMLX_BUILD_GGUF=OFF \
+      -DMLX_METAL_JIT=ON
+
+The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library, which
+contains pre-built GPU kernels. It substantially reduces the size of the Metal
+library by compiling kernels at run time, the first time they are used in MLX
+on a given machine. Note that run-time compilation incurs a cold-start cost,
+which can be anywhere from a few hundred milliseconds to a few seconds
+depending on the application. Once a kernel is compiled, it is cached by the
+system, and the Metal kernel cache persists across reboots.

 Troubleshooting
 ^^^^^^^^^^^^^^^


@@ -1,6 +1,8 @@
 // Copyright © 2023 Apple Inc.

 #pragma once

+#include <cassert>
+
 #include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/utils.h"


@@ -98,12 +98,4 @@ void Cholesky::eval(const std::vector<array>& inputs, array& output) {
   cholesky_impl(inputs[0], output, upper_);
 }

-std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
-    const std::vector<array>& inputs,
-    const std::vector<int>& axes) {
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::cholesky(a, upper_, stream())}, {ax}};
-}
-
 } // namespace mlx::core


@@ -1,33 +1,80 @@
+function(make_jit_source SRC_NAME)
+  # This function takes a metal header file,
+  # runs the C preprocessor on it, and makes
+  # the processed contents available as a string in a C++ function
+  # mlx::core::metal::${SRC_NAME}()
+  #
+  # To use the function, declare it in jit/includes.h and
+  # include jit/includes.h.
+  #
+  # Additional arguments to this function are treated as dependencies
+  # in the CMake build system.
   add_custom_command(
-    OUTPUT compiled_preamble.cpp
+    OUTPUT jit/${SRC_NAME}.cpp
     COMMAND /bin/bash
       ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
-      ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
+      ${CMAKE_CURRENT_BINARY_DIR}/jit
       ${CMAKE_C_COMPILER}
       ${PROJECT_SOURCE_DIR}
+      ${SRC_NAME}
       "-D${MLX_METAL_VERSION}"
     DEPENDS make_compiled_preamble.sh
-      kernels/compiled_preamble.h
-      kernels/unary.h
-      kernels/binary.h
-      kernels/bf16.h
-      kernels/erf.h
-      kernels/expm1f.h
+      kernels/${SRC_NAME}.h
+      ${ARGN}
   )
+  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
+  add_dependencies(mlx ${SRC_NAME})
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
+  )
+endfunction(make_jit_source)

-add_custom_target(
-  compiled_preamble
-  DEPENDS compiled_preamble.cpp
-)
+make_jit_source(
+  utils
+  kernels/bf16.h
+  kernels/complex.h
+)
+make_jit_source(
+  unary_ops
+  kernels/erf.h
+  kernels/expm1f.h
+  kernels/utils.h
+  kernels/bf16_math.h
+)
+make_jit_source(binary_ops)
+make_jit_source(ternary_ops)
+make_jit_source(
+  reduction
+  kernels/atomic.h
+  kernels/reduction/ops.h
+)
+make_jit_source(scatter)
+make_jit_source(gather)

-add_dependencies(mlx compiled_preamble)
+if (MLX_METAL_JIT)
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
+  )
+  make_jit_source(copy)
+  make_jit_source(unary)
+  make_jit_source(binary)
+  make_jit_source(binary_two)
+  make_jit_source(ternary)
+else()
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
+  )
+endif()

 target_sources(
   mlx
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
@@ -46,7 +93,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
 )

 if (NOT MLX_METAL_PATH)
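For orientation, each generated jit/<SRC_NAME>.cpp is expected to contain little more than the preprocessed Metal header embedded as a string, returned by the accessor declared in jit/includes.h. A minimal sketch, assuming this shape for the script's output (the real make_compiled_preamble.sh output may differ in detail):

// Hypothetical shape of a generated jit/utils.cpp (illustrative only).
namespace mlx::core::metal {

const char* utils() {
  // Preprocessed contents of kernels/utils.h, captured at build time.
  return R"preamble(
/* ... preprocessed Metal source ... */
)preamble";
}

} // namespace mlx::core::metal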


@@ -0,0 +1,322 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
void binary_op(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op(inputs, out, "right_shift");
break;
}
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "land");
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lor");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
}
} // namespace mlx::core
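To make the kernel naming above concrete, here is a small sketch (a hypothetical helper, not part of the diff; the dtype suffix assumes type_to_name returns names like "float32"):

#include <string>

// Mirrors how binary_op composes kernel names above.
std::string binary_kernel_name(
    const std::string& mode, // "ss", "sv", "vs", "vv", "g1".."g5", or "gn"
    const std::string& op, // e.g. "add"
    const std::string& dtype) { // e.g. "float32"
  // Two contiguous float32 inputs -> "vv" + "add" + "float32" == "vvaddfloat32"
  return mode + op + dtype;
}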


@@ -4,8 +4,8 @@

 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/compiled_preamble.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/graph_utils.h"
 #include "mlx/primitives.h"
@@ -190,7 +190,8 @@ void Compiled::eval_gpu(
   // If not we have to build it ourselves
   if (lib == nullptr) {
     std::ostringstream kernel;
-    kernel << metal::get_kernel_preamble() << std::endl;
+    kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
+           << metal::ternary_ops();
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous",


@@ -1,9 +0,0 @@
-// Copyright © 2023-24 Apple Inc.
-
-#pragma once
-
-namespace mlx::core::metal {
-
-const char* get_kernel_preamble();
-
-}


@@ -4,12 +4,14 @@

 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"

 namespace mlx::core {

+constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
@@ -62,27 +64,34 @@ void copy_gpu_inplace(
   auto& strides_out_ = strides[1];

   auto& d = metal::device(s.device);
-  std::ostringstream kname;
-  switch (ctype) {
-    case CopyType::Scalar:
-      kname << "scopy";
-      break;
-    case CopyType::Vector:
-      kname << "vcopy";
-      break;
-    case CopyType::General:
-      kname << "gcopy";
-      break;
-    case CopyType::GeneralGeneral:
-      kname << "ggcopy";
-      break;
-  }
-  kname << type_to_name(in) << type_to_name(out);
-  if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
-      shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
-    kname << "_" << shape.size();
-  }
-  auto kernel = d.get_kernel(kname.str());
+  std::string kernel_name;
+  {
+    std::ostringstream kname;
+    switch (ctype) {
+      case CopyType::Scalar:
+        kname << "s";
+        break;
+      case CopyType::Vector:
+        kname << "v";
+        break;
+      case CopyType::General:
+        kname << "g";
+        break;
+      case CopyType::GeneralGeneral:
+        kname << "gg";
+        break;
+    }
+    if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
+        shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
+      kname << shape.size();
+    }
+    kname << "_copy";
+    kname << type_to_name(in) << type_to_name(out);
+    kernel_name = kname.str();
+  }
+  auto kernel = get_copy_kernel(d, kernel_name, in, out);
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_in = in.data_shared_ptr() == nullptr;
@@ -106,7 +115,7 @@ void copy_gpu_inplace(
       set_vector_bytes(compute_encoder, strides_out, ndim, 4);
     }

-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
+    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
       compute_encoder->setBytes(&ndim, sizeof(int), 5);
     }


@@ -285,7 +285,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
   NS::Error* error = nullptr;
   auto options = MTL::CompileOptions::alloc()->init();
   options->setFastMathEnabled(false);
-
   options->setLanguageVersion(get_metal_version());
   auto mtl_lib = device_->newLibrary(ns_code, options, &error);
   options->release();


@@ -1,24 +1,35 @@
 // Copyright © 2023-2024 Apple Inc.

-#include <algorithm>
-#include <cassert>
-#include <numeric>
-#include <sstream>
+#include <fmt/format.h>

-#include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/jit/includes.h"
+#include "mlx/backend/metal/jit/indexing.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

 namespace mlx::core {

-namespace {
-
-constexpr int METAL_MAX_INDEX_ARRAYS = 10;
-
-} // namespace
+constexpr int METAL_MAX_INDEX_ARRAYS = 20;
+
+std::pair<std::string, std::string> make_index_args(
+    const std::string& idx_type,
+    int nidx) {
+  std::ostringstream idx_args;
+  std::ostringstream idx_arr;
+  for (int i = 0; i < nidx; ++i) {
+    idx_args << fmt::format(
+        "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
+    idx_arr << fmt::format("idx{0}", i);
+    if (i < nidx - 1) {
+      idx_args << "\n";
+      idx_arr << ",";
+    }
+  }
+  return {idx_args.str(), idx_arr.str()};
+}

 void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
   size_t ndim = src.ndim();

-  std::ostringstream kname;
+  std::string lib_name;
+  std::string kernel_name;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
-  if (idx_ndim <= 1) {
-    kname << "_" << idx_ndim;
+  {
+    std::ostringstream kname;
+    kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
+          << "_" << idx_ndim;
+    lib_name = kname.str();
+    kernel_name = lib_name;
+  }
+
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gather();
+    std::string out_type_str = get_type_string(out.dtype());
+    std::string idx_type_str =
+        nidx ? get_type_string(inputs[1].dtype()) : "bool";
+    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
+
+    // Index dimension specializations
+    kernel_source << fmt::format(
+        gather_kernels,
+        type_to_name(out) + idx_type_name,
+        out_type_str,
+        idx_type_str,
+        nidx,
+        idx_args,
+        idx_arr,
+        idx_ndim);
+    lib = d.get_library(lib_name, kernel_source.str());
   }

   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kernel_name, lib);
   compute_encoder->setComputePipelineState(kernel);

   size_t slice_size = 1;
@@ -102,8 +139,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);

   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }

   // Launch grid
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);

-  // Get kernel name
-  std::ostringstream kname;
-  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
   bool index_nd1_specialization = (idx_ndim == 1);

@@ -159,32 +192,85 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     index_nd1_specialization &= inputs[i].flags().row_contiguous;
   }

+  std::string lib_name;
+  std::string kernel_name;
+  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
+  std::string op_name;
+  switch (reduce_type_) {
+    case Scatter::None:
+      op_name = "none";
+      break;
+    case Scatter::Sum:
+      op_name = "sum";
+      break;
+    case Scatter::Prod:
+      op_name = "prod";
+      break;
+    case Scatter::Max:
+      op_name = "max";
+      break;
+    case Scatter::Min:
+      op_name = "min";
+      break;
+  }
+
+  {
+    std::ostringstream kname;
     if (index_nd1_specialization) {
       kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
     } else {
       kname << "scatter" << type_to_name(out) << idx_type_name;
     }
+    kname << "_" << op_name << "_" << nidx;
+    lib_name = kname.str();
+    kernel_name = kname.str();
+  }
+
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::reduction() << metal::scatter();
+    std::string out_type_str = get_type_string(out.dtype());
+    std::string idx_type_str =
+        nidx ? get_type_string(inputs[1].dtype()) : "bool";
+    std::string op_type;
     switch (reduce_type_) {
       case Scatter::None:
-        kname << "_none";
+        op_type = "None";
         break;
       case Scatter::Sum:
-        kname << "_sum";
+        op_type = "Sum<{0}>";
         break;
       case Scatter::Prod:
-        kname << "_prod";
+        op_type = "Prod<{0}>";
         break;
       case Scatter::Max:
-        kname << "_max";
+        op_type = "Max<{0}>";
         break;
       case Scatter::Min:
-        kname << "_min";
+        op_type = "Min<{0}>";
         break;
     }
-    kname << "_" << nidx;
+    if (reduce_type_ != Scatter::None) {
+      op_type = fmt::format(op_type, out_type_str);
+    }
+    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
+    kernel_source << fmt::format(
+        scatter_kernels,
+        type_to_name(out) + idx_type_name + "_" + op_name,
+        out_type_str,
+        idx_type_str,
+        op_type,
+        nidx,
+        idx_args,
+        idx_arr);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }

   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kernel_name, lib);

   auto& upd = inputs.back();
   size_t nthreads = upd.size();
@@ -209,8 +295,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);

   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }

   // Launch grid
@@ -279,8 +365,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }

   // Launch grid
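As an illustration of what make_index_args above produces: one buffer-qualified kernel parameter per index array, starting at buffer(20), plus the matching argument list. A sketch, assuming get_type_string maps uint32 to "uint32_t":

// Hypothetical call with two uint32 index arrays (nidx == 2):
auto [idx_args, idx_arr] = make_index_args("uint32_t", 2);
// idx_args == "const device uint32_t *idx0 [[buffer(20)]],\n"
//             "const device uint32_t *idx1 [[buffer(21)]],"
// idx_arr  == "idx0,idx1"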


@@ -0,0 +1,87 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";


@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";


@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";


@@ -0,0 +1,21 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduction();
const char* gather();
const char* scatter();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();
} // namespace mlx::core::metal


@@ -0,0 +1,81 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
src_strides,
src_ndim,
slice_sizes,
axes,
idxs,
index,
grid_dim);
}}
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}
[[kernel]] void scatter{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
updates,
out,
upd_shape,
upd_strides,
upd_ndim,
upd_size,
out_shape,
out_strides,
out_ndim,
axes,
idxs,
gid);
}}
)";


@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
constant const size_t c_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
constant const size_t c_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";


@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
device const {1}* in,
device {1}* out,
uint index [[thread_position_in_grid]]);
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
device const {1}* in,
device {1}* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]);
)";


@@ -0,0 +1,124 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< fmt::format(
unary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
<< fmt::format(
binary_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two()
<< fmt::format(
binary_two_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
<< fmt::format(
ternary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::copy()
<< fmt::format(
copy_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core
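A sketch of the kernel-name to library-name mapping used above (the dtype and op suffixes are assumed examples):

// get_unary_kernel:  "vabsfloat32"          -> lib "absfloat32"  (substr(1))
// get_binary_kernel: "vvaddfloat32"         -> lib "addfloat32"  (substr(2))
// get_copy_kernel:   "s_copyfloat32float32" -> lib "copyfloat32float32"
// All specializations of one op/dtype pair therefore share a single
// MTL::Library, which is built and cached only on first use.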


@@ -0,0 +1,36 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
} // namespace mlx::core


@@ -3,13 +3,8 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
   ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
   ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )
@@ -17,10 +12,7 @@ set(
   KERNELS
   "arange"
   "arg_reduce"
-  "binary"
-  "binary_two"
   "conv"
-  "copy"
   "fft"
   "gemv"
   "quantized"
@@ -32,12 +24,30 @@ set(
   "scaled_dot_product_attention"
   "softmax"
   "sort"
-  "ternary"
-  "unary"
+  "gather"
+  "scatter"
 )

+if (NOT MLX_METAL_JIT)
+  set(
+    KERNELS
+    ${KERNELS}
+    "binary"
+    "binary_two"
+    "unary"
+    "ternary"
+    "copy"
+  )
+  set(
+    HEADERS
+    ${HEADERS}
+    unary_ops.h
+    unary.h
+    binary_ops.h
+    binary.h
+    ternary.h
+    copy.h
+  )
+endif()

 function(build_kernel_base TARGET SRCFILE DEPS)
   set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
   if(MLX_METAL_DEBUG)


@@ -2,9 +2,11 @@

 #pragma once

+#ifndef MLX_METAL_JIT
 #include <metal_atomic>
 #include <metal_stdlib>
 #include "mlx/backend/metal/kernels/bf16.h"
+#endif

 using namespace metal;


@ -1,273 +1,113 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2024 Apple Inc.
#pragma once template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
#include <metal_integer> device const T* a,
#include <metal_math> device const T* b,
device U* c,
#include "mlx/backend/metal/kernels/bf16.h" uint index [[thread_position_in_grid]]) {
#include "mlx/backend/metal/kernels/utils.h" c[index] = Op()(a[0], b[0]);
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};

template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
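For reference, a minimal host-side C++ sketch (illustrative only, not code from this commit) of how the host_name strings registered by the instantiation macros could be assembled: "ss"/"sv"/"vs"/"vv" for the contiguous variants, "g1".."g5" for the dimension-specialized strided kernels, and "gn" as the general fallback. The helper name and its dispatch policy are assumptions.

// Hedged sketch: assembling kernel names such as "vvaddfloat32",
// "g2addfloat32", or the general fallback "gnaddfloat32".
#include <cstdio>
#include <string>

std::string binary_kernel_name(bool a_scalar, bool b_scalar,
                               const std::string& op, const std::string& tname,
                               int strided_ndim /* 0 if contiguous */) {
  if (a_scalar && b_scalar) return "ss" + op + tname;  // scalar op scalar
  if (a_scalar) return "sv" + op + tname;              // scalar op vector
  if (b_scalar) return "vs" + op + tname;              // vector op scalar
  if (strided_ndim == 0) return "vv" + op + tname;     // contiguous vectors
  return (strided_ndim <= 5)                           // g1..g5 specialized
      ? "g" + std::to_string(strided_ndim) + op + tname
      : "gn" + op + tname;                             // general strided
}

int main() {
  std::printf("%s\n", binary_kernel_name(false, false, "add", "float32", 2).c_str());  // g2addfloat32
  return 0;
}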
};

View File

@ -1,130 +1,24 @@
// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -135,16 +29,16 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -152,8 +46,8 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -165,7 +59,7 @@ template <typename T, typename U, typename Op>
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -176,16 +70,14 @@ template <typename T, typename U, typename Op>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@ -194,22 +86,19 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@ -223,9 +112,8 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)

View File

@ -0,0 +1,296 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct FloorDivide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};
struct DivMod {
template <typename T>
metal::array<T, 2> operator()(T x, T y) {
return {FloorDivide{}(x, y), Remainder{}(x, y)};
};
};
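The integral Power overload above uses exponentiation by squaring, which needs only one multiply per bit of the exponent. A standalone C++ sketch of the same loop (hypothetical test harness, not part of the header):

// Hedged sketch: the exponentiation-by-squaring loop from Power,
// lifted into plain C++ with a quick sanity check.
#include <cassert>
#include <cstdint>

int64_t ipow(int64_t base, int64_t exp) {
  int64_t res = 1;
  while (exp) {
    if (exp & 1) {
      res *= base;  // fold in the current bit of the exponent
    }
    exp >>= 1;      // advance to the next bit
    base *= base;   // square the base once per bit
  }
  return res;
}

int main() {
  assert(ipow(3, 5) == 243);
  assert(ipow(2, 10) == 1024);
  return 0;
}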

View File

@ -0,0 +1,140 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
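These binary_two kernels differ from the single-output ones in that Op returns a two-element array (DivMod packs FloorDivide and Remainder), with one element written to c and one to d. A hedged C++ emulation of the "vv" case (names illustrative, not from this diff):

// Hedged sketch: one Op invocation per element, quotient to c and
// remainder to d, mirroring how binary_vv is used with DivMod.
#include <array>
#include <cstdio>

struct DivModInt {
  std::array<int, 2> operator()(int x, int y) {
    int q = x / y;  // FloorDivide for integrals is plain x / y
    int r = x % y;
    if (r != 0 && ((r < 0) != (y < 0))) {
      r += y;       // Remainder shifts toward the divisor's sign
    }
    return {q, r};
  }
};

template <typename T, typename Op>
void binary_two_vv(const T* a, const T* b, T* c, T* d, int n) {
  for (int i = 0; i < n; ++i) {
    auto out = Op()(a[i], b[i]);
    c[i] = out[0];
    d[i] = out[1];
  }
}

int main() {
  int a[2] = {7, -7}, b[2] = {3, 3}, q[2], r[2];
  binary_two_vv<int, DivModInt>(a, b, q, r, 2);
  std::printf("%d %d %d %d\n", q[0], r[0], q[1], r[1]);  // 2 1 -2 2
  return 0;
}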

View File

@ -1,212 +1,24 @@
// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
struct FloorDivide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[index]);
d[index] = Op2()(a[0], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[0]);
d[index] = Op2()(a[index], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[index]);
d[index] = Op2()(a[index], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op1()(a[a_idx], b[b_idx]);
d[index] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] [[kernel]] void \
binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -217,10 +29,9 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -228,8 +39,8 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -238,8 +49,8 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -248,12 +59,12 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void \
binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@ -265,33 +76,30 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_types(divmod, DivMod) // clang-format on

View File

@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"
typedef half float16_t;

View File

@ -0,0 +1,144 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
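All of the strided copy, binary, and gather kernels above lean on the same elem_to_loc stride arithmetic: converting a flat element index into a memory offset for a possibly non-contiguous array. A hedged standalone C++ sketch of that computation (the signature here is illustrative; the real helpers live in utils.h):

// Hedged sketch: flat index -> strided memory offset, iterating axes
// from innermost to outermost as the elem_to_loc family does.
#include <cstddef>
#include <cstdio>

size_t elem_to_loc(size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];  // offset within this axis
    elem /= shape[i];                       // carry to the next axis
  }
  return loc;
}

int main() {
  // A 2x3 view stored transposed: shape (2, 3), strides (1, 2).
  int shape[2] = {2, 3};
  size_t strides[2] = {1, 2};
  // Flat element 4 is (row 1, col 1) -> 1*1 + 1*2 = 3.
  std::printf("%zu\n", elem_to_loc(4, shape, strides, 2));
  return 0;
}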

View File

@ -1,150 +1,9 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
@ -153,7 +12,7 @@ template <typename T, typename U>
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
@ -161,7 +20,7 @@ template <typename T, typename U>
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
@ -171,38 +30,38 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg1_" name)]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("gg2_" name)]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("gg3_" name)]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
@ -213,7 +72,7 @@ template <typename T, typename U>
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
@ -221,7 +80,7 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
@ -230,14 +89,12 @@ template <typename T, typename U>
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("s_copy" #tname, itype, otype, s) \
instantiate_copy("v_copy" #tname, itype, otype, v) \
instantiate_copy_g("copy" #tname, itype, otype) \
instantiate_copy_g_nd("copy" #tname, itype, otype)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \

View File

@ -8,8 +8,6 @@
#define MTL_CONST
#endif
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;

View File

@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
/*

View File

@ -0,0 +1,45 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
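gather_impl accumulates one source offset per indexing axis, normalizing each fetched index through offset_neg_idx before scaling by the axis stride, then adds the within-slice offset. A hedged C++ sketch of that index normalization (the real helper lives in indexing.h; this standalone version is illustrative):

// Hedged sketch: negative-index wraparound as offset_neg_idx is
// expected to behave for signed index types.
#include <cassert>

long offset_neg_idx(long idx, int axis_size) {
  return idx < 0 ? idx + axis_size : idx;  // e.g. -1 -> axis_size - 1
}

int main() {
  assert(offset_neg_idx(-1, 5) == 4);
  assert(offset_neg_idx(3, 5) == 3);
  return 0;
}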

View File

@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T* src [[buffer(0)]], \
device T* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
gather<src_t, idx_t, nidx, nd>( \
const device src_t* src [[buffer(0)]], \
device src_t* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
// clang-format off
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,13 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
+#pragma once
#include <metal_stdlib>
-using namespace metal;
-/////////////////////////////////////////////////////////////////////
-// Indexing utils
-/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
  const array<const device IdxT*, NIDX> buffers;
@@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}
}
-#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
-#define IDX_ARG_0(idx_t)
-#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
-#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
-#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
-#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
-#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
-#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
-#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
-#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
-#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
-#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
-#define IDX_ARR_N(n) idx##n,
-#define IDX_ARR_0()
-#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
-#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
-#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
-#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
-#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
-#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
-#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
-#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
-#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
-#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
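The retained offset_neg_idx gives these kernels Python-style negative indexing: an index below zero wraps once from the end of the axis. A host-side equivalent, with an illustrative name:

#include <cstddef>
// Host-side model of offset_neg_idx: wrap a negative index once.
inline ptrdiff_t wrap_index(ptrdiff_t idx, size_t axis_size) {
  return (idx < 0) ? idx + static_cast<ptrdiff_t>(axis_size) : idx;
}
// wrap_index(-1, 5) == 4; wrap_index(3, 5) == 3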

View File

@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"

View File

@@ -5,9 +5,11 @@
#include <metal_atomic>
#include <metal_simdgroup>
+#ifndef MLX_METAL_JIT
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
+#endif
union bool4_or_uint {
bool4 b;

View File

@@ -0,0 +1,66 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}
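For intuition, scatter_1d_index_impl reduces, for a single index array and a contiguous 2-D output whose row stride equals upd_size, to the sequential model below; on device the same accumulation runs through op.atomic_update, so concurrent threads may target the same row. This is an illustrative sketch with Op = Sum and negative-index wrapping (offset_neg_idx) omitted for brevity; cpu_scatter_add is not part of this change.

#include <cstddef>
#include <vector>

// Sequential CPU model of scatter_1d_index_impl with Op = Sum.
void cpu_scatter_add(
    std::vector<float>& out, // shape [rows, upd_size], contiguous
    const std::vector<float>& updates, // shape [idx.size(), upd_size]
    const std::vector<int>& idx,
    size_t upd_size) {
  for (size_t y = 0; y < idx.size(); ++y) { // gid.y in the kernel
    size_t out_idx = idx[y] * upd_size; // idx_val * out_strides[0]
    for (size_t x = 0; x < upd_size; ++x) { // gid.x in the kernel
      out[out_idx + x] += updates[y * upd_size + x]; // Sum via atomic_update
    }
  }
}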

View File

@@ -1,236 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
make_scatter(9) make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
scatter<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
// clang-format off
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on
// clang-format off
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on
// clang-format off
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on
// clang-format off
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on
// clang-format off
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t) // clang-format on
// clang-format off
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,10 +1,102 @@
+// Copyright © 2024 Apple Inc.
+template <typename T, typename Op>
+[[kernel]] void ternary_v(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    uint index [[thread_position_in_grid]]) {
+  d[index] = Op()(a[index], b[index], c[index]);
+}
+template <typename T, typename Op>
+[[kernel]] void ternary_g_nd1(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t& a_strides,
+    constant const size_t& b_strides,
+    constant const size_t& c_strides,
+    uint index [[thread_position_in_grid]]) {
+  auto a_idx = elem_to_loc_1(index, a_strides);
+  auto b_idx = elem_to_loc_1(index, b_strides);
+  auto c_idx = elem_to_loc_1(index, c_strides);
+  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+template <typename T, typename Op>
+[[kernel]] void ternary_g_nd2(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[2],
+    constant const size_t b_strides[2],
+    constant const size_t c_strides[2],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_2(index, a_strides);
+  auto b_idx = elem_to_loc_2(index, b_strides);
+  auto c_idx = elem_to_loc_2(index, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+template <typename T, typename Op>
+[[kernel]] void ternary_g_nd3(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[3],
+    constant const size_t b_strides[3],
+    constant const size_t c_strides[3],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_3(index, a_strides);
+  auto b_idx = elem_to_loc_3(index, b_strides);
+  auto c_idx = elem_to_loc_3(index, c_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+template <typename T, typename Op, int DIM>
+[[kernel]] void ternary_g_nd(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int shape[DIM],
+    constant const size_t a_strides[DIM],
+    constant const size_t b_strides[DIM],
+    constant const size_t c_strides[DIM],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx =
+      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
+template <typename T, typename Op>
+[[kernel]] void ternary_g(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    constant const int& ndim,
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx =
+      elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
-// Copyright © 2023-2024 Apple Inc.
-#pragma once
-struct Select {
-  template <typename T>
-  T operator()(bool condition, T x, T y) {
-    return condition ? x : y;
-  }
-};

View File

@@ -1,115 +1,16 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
+// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/ternary_ops.h"
+#include "mlx/backend/metal/kernels/ternary.h"
template <typename T, typename Op>
[[kernel]] void ternary_op_v(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
uint index [[thread_position_in_grid]]) {
d[index] = Op()(a[index], b[index], c[index]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd1(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_strides);
auto b_idx = elem_to_loc_1(index, b_strides);
auto c_idx = elem_to_loc_1(index, c_strides);
d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
auto c_idx = elem_to_loc_2(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g_nd3(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
template <typename T, typename Op, int DIM>
[[kernel]] void ternary_op_g_nd(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
constant const size_t c_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
template <typename T, typename Op>
[[kernel]] void ternary_op_g(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
#define instantiate_ternary_v(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
+  template [[host_name("v_" name)]] [[kernel]] void ternary_v<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -117,7 +18,7 @@ template <typename T, typename Op>
      uint index [[thread_position_in_grid]]);
#define instantiate_ternary_g(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
+  template [[host_name("g_" name)]] [[kernel]] void ternary_g<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -131,8 +32,8 @@ template <typename T, typename Op>
      uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_dim(name, type, op, dims) \
-  template [[host_name(name "_" #dims)]] [[kernel]] void \
-  ternary_op_g_nd<type, op, dims>( \
+  template [[host_name("g" #dims "_" name)]] [[kernel]] void \
+  ternary_g_nd<type, op, dims>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -145,8 +46,8 @@ template <typename T, typename Op>
      uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_nd(name, type, op) \
-  template [[host_name(name "_1")]] [[kernel]] void \
-  ternary_op_g_nd1<type, op>( \
+  template [[host_name("g1_" name)]] [[kernel]] void \
+  ternary_g_nd1<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -155,8 +56,8 @@ template <typename T, typename Op>
      constant const size_t& b_strides, \
      constant const size_t& c_strides, \
      uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] [[kernel]] void \
-  ternary_op_g_nd2<type, op>( \
+  template [[host_name("g2_" name)]] [[kernel]] void \
+  ternary_g_nd2<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -166,8 +67,8 @@ template <typename T, typename Op>
      constant const size_t c_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] [[kernel]] void \
-  ternary_op_g_nd3<type, op>( \
+  template [[host_name("g3_" name)]] [[kernel]] void \
+  ternary_g_nd3<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -180,13 +81,11 @@ template <typename T, typename Op>
      instantiate_ternary_g_dim(name, type, op, 4) \
      instantiate_ternary_g_dim(name, type, op, 5)
-// clang-format off
#define instantiate_ternary_all(name, tname, type, op) \
-  instantiate_ternary_v("v" #name #tname, type, op) \
-  instantiate_ternary_g("g" #name #tname, type, op) \
-  instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
+  instantiate_ternary_v(#name #tname, type, op) \
+  instantiate_ternary_g(#name #tname, type, op) \
+  instantiate_ternary_g_nd(#name #tname, type, op)
-// clang-format off
#define instantiate_ternary_types(name, op) \
  instantiate_ternary_all(name, bool_, bool, op) \
  instantiate_ternary_all(name, uint8, uint8_t, op) \
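The rename moves the placement prefix out of the instantiation strings and into the host_name attribute, so a host-side dispatcher can assemble kernel names uniformly. A sketch of that construction, assuming an op string like "select" and a type string like "float32"; whether the dispatcher in this commit uses exactly such a helper is an assumption:

#include <string>
// Assemble new-style ternary kernel names: "v_selectfloat32",
// "g_selectfloat32", or "g3_selectfloat32" for the 3-D specialization.
std::string ternary_kernel_name(
    bool general, int ndim, const std::string& op, const std::string& type) {
  if (!general) {
    return "v_" + op + type; // contiguous (vector) kernel
  }
  if (ndim >= 1 && ndim <= 5) {
    return "g" + std::to_string(ndim) + "_" + op + type; // dim-specialized
  }
  return "g_" + op + type; // general fallback with runtime ndim
}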

View File

@@ -0,0 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
struct Select {
template <typename T>
T operator()(bool condition, T x, T y) {
return condition ? x : y;
}
};
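Select is the functor behind elementwise where(condition, x, y): paired with ternary_v, each thread picks between two inputs. A CPU model, illustrative only:

#include <cstddef>
#include <vector>

// Elementwise where(condition, x, y), mirroring Op = Select in ternary_v.
std::vector<float> cpu_where(
    const std::vector<bool>& a,
    const std::vector<float>& b,
    const std::vector<float>& c) {
  std::vector<float> d(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    d[i] = a[i] ? b[i] : c[i]; // Select{}(a[i], b[i], c[i])
  }
  return d;
}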

View File

@@ -1,394 +1,21 @@
+// Copyright © 2024 Apple Inc.
+template <typename T, typename Op>
+[[kernel]] void unary_v(
+    device const T* in,
+    device T* out,
+    uint index [[thread_position_in_grid]]) {
+  out[index] = Op()(in[index]);
+}
+template <typename T, typename Op>
+[[kernel]] void unary_g(
+    device const T* in,
+    device T* out,
+    device const int* in_shape,
+    device const size_t* in_strides,
+    device const int& ndim,
+    uint index [[thread_position_in_grid]]) {
+  auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
+  out[index] = Op()(in[idx]);
+}
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
#include "mlx/backend/metal/kernels/utils.h"
namespace {
constant float inf = metal::numeric_limits<float>::infinity();
}
struct Abs {
template <typename T>
T operator()(T x) {
return metal::abs(x);
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return metal::precise::acos(x);
};
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return metal::precise::acosh(x);
};
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return metal::precise::asin(x);
};
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return metal::precise::asinh(x);
};
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return metal::precise::atan(x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return metal::precise::atanh(x);
};
};
struct Ceil {
template <typename T>
T operator()(T x) {
return metal::ceil(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Cos {
template <typename T>
T operator()(T x) {
return metal::precise::cos(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Cosh {
template <typename T>
T operator()(T x) {
return metal::precise::cosh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return complex64_t{x.real, -x.imag};
}
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(erf(static_cast<float>(x)));
};
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(erfinv(static_cast<float>(x)));
};
};
struct Exp {
template <typename T>
T operator()(T x) {
return metal::precise::exp(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return static_cast<T>(expm1f(static_cast<float>(x)));
};
};
struct Floor {
template <typename T>
T operator()(T x) {
return metal::floor(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Log {
template <typename T>
T operator()(T x) {
return metal::precise::log(x);
};
};
struct Log2 {
template <typename T>
T operator()(T x) {
return metal::precise::log2(x);
};
};
struct Log10 {
template <typename T>
T operator()(T x) {
return metal::precise::log10(x);
};
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
};
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
};
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
};
};
struct Round {
template <typename T>
T operator()(T x) {
return metal::rint(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)};
};
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
};
template <>
uint32_t operator()(uint32_t x) {
return x != 0;
};
};
struct Sin {
template <typename T>
T operator()(T x) {
return metal::precise::sin(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Sinh {
template <typename T>
T operator()(T x) {
return metal::precise::sinh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
};
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return metal::precise::sqrt(x);
};
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return metal::precise::rsqrt(x);
};
};
struct Tan {
template <typename T>
T operator()(T x) {
return metal::precise::tan(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
};
};
struct Tanh {
template <typename T>
T operator()(T x) {
return metal::precise::tanh(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
};
};

View File

@@ -1,35 +1,18 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.
+// clang-format off
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
-template <typename T, typename Op>
-[[kernel]] void unary_op_v(
-    device const T* in,
-    device T* out,
-    uint index [[thread_position_in_grid]]) {
-  out[index] = Op()(in[index]);
-}
-template <typename T, typename Op>
-[[kernel]] void unary_op_g(
-    device const T* in,
-    device T* out,
-    device const int* in_shape,
-    device const size_t* in_strides,
-    device const int& ndim,
-    uint index [[thread_position_in_grid]]) {
-  auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
-  out[index] = Op()(in[idx]);
-}
#define instantiate_unary_v(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
+  template [[host_name(name)]] [[kernel]] void unary_v<type, op>( \
      device const type* in, \
      device type* out, \
      uint index [[thread_position_in_grid]]);
#define instantiate_unary_g(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
+  template [[host_name(name)]] [[kernel]] void unary_g<type, op>( \
      device const type* in, \
      device type* out, \
      device const int* in_shape, \
@@ -37,18 +20,15 @@ template <typename T, typename Op>
      device const int& ndim, \
      uint index [[thread_position_in_grid]]);
-// clang-format off
#define instantiate_unary_all(name, tname, type, op) \
  instantiate_unary_v("v" #name #tname, type, op) \
-  instantiate_unary_g("g" #name #tname, type, op) // clang-format on
+  instantiate_unary_g("g" #name #tname, type, op)
-// clang-format off
#define instantiate_unary_float(name, op) \
  instantiate_unary_all(name, float16, half, op) \
  instantiate_unary_all(name, float32, float, op) \
-  instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
+  instantiate_unary_all(name, bfloat16, bfloat16_t, op)
-// clang-format off
#define instantiate_unary_types(name, op) \
  instantiate_unary_all(name, bool_, bool, op) \
  instantiate_unary_all(name, uint8, uint8_t, op) \
@@ -59,9 +39,8 @@ template <typename T, typename Op>
  instantiate_unary_all(name, int16, int16_t, op) \
  instantiate_unary_all(name, int32, int32_t, op) \
  instantiate_unary_all(name, int64, int64_t, op) \
-  instantiate_unary_float(name, op) // clang-format on
+  instantiate_unary_float(name, op)
-// clang-format off
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
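Concretely, instantiate_unary_types(abs, Abs) above passes through instantiate_unary_all(abs, float32, float, Abs), which expands to two explicit instantiations whose host-visible names keep the v/g prefixes. A sketch of that expansion (not literal preprocessor output):

// Expansion sketch for the float32 abs variants.
template [[host_name("vabsfloat32")]] [[kernel]] void unary_v<float, Abs>(
    device const float* in,
    device float* out,
    uint index [[thread_position_in_grid]]);
template [[host_name("gabsfloat32")]] [[kernel]] void unary_g<float, Abs>(
    device const float* in,
    device float* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]);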

View File

@@ -0,0 +1,392 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
namespace {
constant float inf = metal::numeric_limits<float>::infinity();
}
struct Abs {
template <typename T>
T operator()(T x) {
return metal::abs(x);
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
};
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return metal::precise::acos(x);
};
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return metal::precise::acosh(x);
};
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return metal::precise::asin(x);
};
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return metal::precise::asinh(x);
};
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return metal::precise::atan(x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return metal::precise::atanh(x);
};
};
struct Ceil {
template <typename T>
T operator()(T x) {
return metal::ceil(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Cos {
template <typename T>
T operator()(T x) {
return metal::precise::cos(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Cosh {
template <typename T>
T operator()(T x) {
return metal::precise::cosh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return complex64_t{x.real, -x.imag};
}
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(erf(static_cast<float>(x)));
};
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(erfinv(static_cast<float>(x)));
};
};
struct Exp {
template <typename T>
T operator()(T x) {
return metal::precise::exp(x);
};
template <>
complex64_t operator()(complex64_t x) {
auto m = metal::precise::exp(x.real);
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
}
};
struct Expm1 {
template <typename T>
T operator()(T x) {
return static_cast<T>(expm1f(static_cast<float>(x)));
};
};
struct Floor {
template <typename T>
T operator()(T x) {
return metal::floor(x);
};
template <>
int8_t operator()(int8_t x) {
return x;
};
template <>
int16_t operator()(int16_t x) {
return x;
};
template <>
int32_t operator()(int32_t x) {
return x;
};
template <>
int64_t operator()(int64_t x) {
return x;
};
template <>
uint8_t operator()(uint8_t x) {
return x;
};
template <>
uint16_t operator()(uint16_t x) {
return x;
};
template <>
uint32_t operator()(uint32_t x) {
return x;
};
template <>
uint64_t operator()(uint64_t x) {
return x;
};
template <>
bool operator()(bool x) {
return x;
};
};
struct Log {
template <typename T>
T operator()(T x) {
return metal::precise::log(x);
};
};
struct Log2 {
template <typename T>
T operator()(T x) {
return metal::precise::log2(x);
};
};
struct Log10 {
template <typename T>
T operator()(T x) {
return metal::precise::log10(x);
};
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
};
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
};
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
};
};
struct Round {
template <typename T>
T operator()(T x) {
return metal::rint(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {metal::rint(x.real), metal::rint(x.imag)};
};
};
struct Sigmoid {
template <typename T>
T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
struct Sign {
template <typename T>
T operator()(T x) {
return (x > T(0)) - (x < T(0));
};
template <>
uint32_t operator()(uint32_t x) {
return x != 0;
};
};
struct Sin {
template <typename T>
T operator()(T x) {
return metal::precise::sin(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
};
};
struct Sinh {
template <typename T>
T operator()(T x) {
return metal::precise::sinh(x);
};
template <>
complex64_t operator()(complex64_t x) {
return {
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
};
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
};
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return metal::precise::sqrt(x);
};
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return metal::precise::rsqrt(x);
};
};
struct Tan {
template <typename T>
T operator()(T x) {
return metal::precise::tan(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tan_a = metal::precise::tan(x.real);
float tanh_b = metal::precise::tanh(x.imag);
float t1 = tan_a * tanh_b;
float denom = 1. + t1 * t1;
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
};
};
struct Tanh {
template <typename T>
T operator()(T x) {
return metal::precise::tanh(x);
};
template <>
complex64_t operator()(complex64_t x) {
float tanh_a = metal::precise::tanh(x.real);
float tan_b = metal::precise::tan(x.imag);
float t1 = tanh_a * tan_b;
float denom = 1. + t1 * t1;
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
};
};

View File

@@ -6,6 +6,8 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"
+typedef half float16_t;
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

View File

@@ -5,24 +5,25 @@
#
# Copyright © 2023-24 Apple Inc.
-OUTPUT_FILE=$1
+OUTPUT_DIR=$1
CC=$2
-SRCDIR=$3
-CFLAGS=$4
+SRC_DIR=$3
+SRC_NAME=$4
+CFLAGS=$5
+INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h
+OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
+mkdir -p $OUTPUT_DIR
-CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
+CONTENT=$($CC -I $SRC_DIR -DMLX_METAL_JIT -E -P $INPUT_FILE $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
-// Copyright © 2023-24 Apple Inc.
namespace mlx::core::metal {
-const char* get_kernel_preamble() {
+const char* $SRC_NAME() {
return R"preamble(
$CONTENT
)preamble";
}
} // namespace mlx::core::metal
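The reworked script emits one C++ translation unit per kernel header, keyed by SRC_NAME. Invoked with, say, OUTPUT_DIR=build/jit and SRC_NAME=unary_ops (values illustrative), it writes build/jit/unary_ops.cpp embedding the preprocessed header as a raw string, roughly:

// Sketch of the generated translation unit (contents abridged).
namespace mlx::core::metal {
const char* unary_ops() {
  return R"preamble(
// ... preprocessed contents of mlx/backend/metal/kernels/unary_ops.h ...
)preamble";
}
} // namespace mlx::core::metal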

View File

@@ -0,0 +1,46 @@
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
return d.get_kernel(kernel_name);
}
} // namespace mlx::core

View File

@@ -4,366 +4,14 @@
#include <numeric>
#include <sstream>
-#include "mlx/backend/common/binary.h"
-#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
void binary_op(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == BinaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
break;
}
kname << op << type_to_name(a);
if (bopt == BinaryOpType::General &&
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void ternary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 3);
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_c = strides[2];
auto& strides_out = strides[3];
std::ostringstream kname;
if (topt == TernaryOpType::General) {
kname << "g";
kname << op << type_to_name(b);
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
} else {
kname << "v";
kname << op << type_to_name(b);
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
if (topt == TernaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void unary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (in.size() == 0) {
return;
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::string tname = type_to_name(in);
std::string opt_name = contig ? "v" : "g";
auto kernel = d.get_kernel(opt_name + op + tname);
size_t nthreads = contig ? in.data_size() : in.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
compute_encoder->setBytes(
in.strides().data(), in.ndim() * sizeof(size_t), 3);
int ndim = in.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "abs");
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
}
template <typename T>
void arange_set_scalars(T start, T next, CommandEncoder& enc) {
enc->setBytes(&start, sizeof(T), 0);
@@ -431,34 +79,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccos");
}
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccosh");
}
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsin");
}
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsinh");
}
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -537,26 +157,6 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op(inputs, out, "right_shift");
break;
}
}
void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
@@ -588,29 +188,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_op(inputs, out, "conj");
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cos");
}
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
}
void CustomVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -623,40 +204,6 @@ void Depends::eval_gpu(
eval(inputs, outputs);
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erf");
}
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erfinv");
}
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
}
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "expm1");
}
void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
CopyType ctype;
@@ -670,102 +217,14 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu(in, out, ctype);
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op(inputs, out, "log");
break;
case Base::two:
unary_op(inputs, out, "log2");
break;
case Base::ten:
unary_op(inputs, out, "log10");
break;
}
}
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "log1p");
}
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "lnot");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "land"); // "land" is the kernel identifier for logical AND
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lor"); // "lor" is the kernel identifier for logical OR
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "floor");
}
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "ceil");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op(inputs, out, "select");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
@@ -797,10 +256,6 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
}
void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -861,51 +316,12 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round");
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sigmoid");
}
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sign");
}
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sin");
}
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
void Split::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
if (recip_) {
unary_op(inputs, out, "rsqrt");
} else {
unary_op(inputs, out, "sqrt");
}
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
@@ -980,18 +396,6 @@ void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
}
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tan");
}
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tanh");
}
void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

mlx/backend/metal/ternary.cpp Normal file
View File

@@ -0,0 +1,108 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
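// Kernels carry an explicit rank suffix up to this many dimensions;
// higher-rank general cases pass ndim to a looping kernel at runtime.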
void ternary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 3);
auto& a = inputs[0];
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_c = strides[2];
auto& strides_out = strides[3];
std::string kernel_name;
{
std::ostringstream kname;
if (topt == TernaryOpType::General) {
kname << "g";
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
kname << shape.size();
}
} else {
kname << "v";
}
kname << "_" << op << type_to_name(b);
kernel_name = kname.str();
}
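// e.g. a 3-D general float32 select resolves to "g3_selectfloat32"; the
// contiguous fast path is "v_selectfloat32".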
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_ternary_kernel(d, kernel_name, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
if (topt == TernaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
if (ndim > MAX_TERNARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 8);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
ternary_op(inputs, out, "select");
}
} // namespace mlx::core
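For reference, the semantics the "select" kernels implement over the three buffers bound above: elementwise out = a ? b : c, with a the boolean condition. A plain C++ sketch of that behavior (select_reference is a hypothetical name, not the Metal kernel source):

#include <cstddef>

template <typename T>
void select_reference(
    const bool* a, const T* b, const T* c, T* out, size_t n) {
  // Mirrors the contiguous "v_select" path: one flat pass over n elements.
  for (size_t i = 0; i < n; ++i) {
    out[i] = a[i] ? b[i] : c[i];
  }
}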

mlx/backend/metal/unary.cpp Normal file
View File

@@ -0,0 +1,206 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void unary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& in = inputs[0];
bool contig = in.flags().contiguous;
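// Contiguous inputs are processed in flat order: steal the input buffer
// when it is donatable and the itemsizes match, otherwise allocate with
// the input's strides and flags. Strided inputs get a dense allocation.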
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (in.size() == 0) {
return;
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
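// e.g. contiguous float32 abs resolves to "vabsfloat32"; the strided
// variant is "gabsfloat32".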
auto kernel = get_unary_kernel(d, kernel_name, out);
size_t nthreads = contig ? in.data_size() : in.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
compute_encoder->setBytes(
in.strides().data(), in.ndim() * sizeof(size_t), 3);
int ndim = in.ndim();
compute_encoder->setBytes(&ndim, sizeof(int), 4);
}
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "abs");
}
void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccos");
}
void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arccosh");
}
void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsin");
}
void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arcsinh");
}
void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctan");
}
void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "arctanh");
}
void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_op(inputs, out, "conj");
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cos");
}
void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
}
void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erf");
}
void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "erfinv");
}
void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "exp");
}
void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "expm1");
}
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (base_) {
case Base::e:
unary_op(inputs, out, "log");
break;
case Base::two:
unary_op(inputs, out, "log2");
break;
case Base::ten:
unary_op(inputs, out, "log10");
break;
}
}
void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "log1p");
}
void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "lnot");
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "floor");
}
void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "ceil");
}
void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "neg");
}
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (issubdtype(in.dtype(), inexact)) {
unary_op(inputs, out, "round");
} else {
// No-op integer types
out.copy_shared_buffer(in);
}
}
void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sigmoid");
}
void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sign");
}
void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sin");
}
void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "sinh");
}
void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "square");
}
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
if (recip_) {
unary_op(inputs, out, "rsqrt");
} else {
unary_op(inputs, out, "sqrt");
}
}
void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tan");
}
void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "tanh");
}
} // namespace mlx::core
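The strided "g" kernels must turn a flat thread index into a memory offset using the shape and strides bound at buffer indices 2 and 3 above. A minimal sketch of that translation, walking axes from fastest to slowest (a stand-in for the kernel-side helper, not its actual source):

#include <cstddef>

size_t elem_to_loc(
    size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i]; // coordinate along axis i
    elem /= shape[i];
  }
  return loc;
}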

View File

@@ -33,9 +33,9 @@ NO_CPU(AsType)
NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BlockMaskedMM)
NO_CPU(BlockSparseMM)
NO_CPU(Broadcast)
NO_CPU(Ceil)
NO_CPU(Cholesky)
NO_CPU(Concatenate)
NO_CPU(Conjugate)
NO_CPU(Convolution)
@@ -57,6 +57,8 @@ NO_CPU(FFT)
NO_CPU(Floor)
NO_CPU(Full)
NO_CPU(Gather)
NO_CPU(GatherMM)
NO_CPU(GatherQMM)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Less)

View File

@@ -41,7 +41,7 @@ if (MLX_BUILD_GGUF)
gguflib STATIC
${gguflib_SOURCE_DIR}/fp16.c
${gguflib_SOURCE_DIR}/gguflib.c)
-target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
+target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
target_sources(
mlx
PRIVATE

View File

@@ -708,6 +708,14 @@ std::pair<std::vector<array>, std::vector<int>> Ceil::vmap(
return {{ceil(inputs[0], stream())}, axes};
}
std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
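// Move the batched axis to the front so cholesky factors a stack of
// square matrices; a negative axis marks an unbatched input.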
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}
std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -870,7 +870,7 @@ class Equal : public UnaryPrimitive {
void print(std::ostream& os) override {
if (equal_nan_) {
-      os << "NanEqual";
+      os << "NaNEqual";
} else {
os << "Equal";
}