Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
JIT compile option for binary minimization (#1091)

* try cpp 20 for compile
* unary, binary, ternary in jit
* nits
* fix gather/scatter
* fix rebase
* reorg compile
* add ternary to compile
* jit copy
* jit compile flag
* fix build
* use linked function for ternary
* some nits
* docs + circle min size build
* docs + circle min size build
* fix extension
* fix no cpu build
* improve includes

parent d568c7ee36
commit 226748b3e7
@@ -114,7 +114,13 @@ jobs:
           name: Run CPP tests
           command: |
             DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
-            DEVICE=cpu ./build/tests/tests
+      - run:
+          name: Build small binary
+          command: |
+            source env/bin/activate
+            cd build/
+            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
+            make -j

   build_release:
     parameters:
@@ -20,6 +20,7 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
+option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

 if(NOT MLX_VERSION)
@@ -109,7 +110,7 @@ elseif (MLX_BUILD_METAL)
     $<INSTALL_INTERFACE:include/metal_cpp>
   )
   target_link_libraries(
-    mlx
+    mlx PUBLIC
     ${METAL_LIB}
     ${FOUNDATION_LIB}
     ${QUARTZ_LIB})
@@ -122,7 +123,7 @@ if (MLX_BUILD_CPU)
   if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
-    target_link_libraries(mlx ${ACCELERATE_LIBRARY})
+    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
     add_compile_definitions(ACCELERATE_NEW_LAPACK)
   else()
     message(STATUS "Accelerate or arm neon not found, using default backend.")
@@ -145,7 +146,7 @@ if (MLX_BUILD_CPU)
     message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
     message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-    target_link_libraries(mlx ${LAPACK_LIBRARIES})
+    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
     # List blas after lapack otherwise we may accidentally incldue an old version
     # of lapack.h from the include dirs of blas.
     find_package(BLAS REQUIRED)
@@ -160,7 +161,7 @@ if (MLX_BUILD_CPU)
     message(STATUS "Blas lib " ${BLAS_LIBRARIES})
     message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-    target_link_libraries(mlx ${BLAS_LIBRARIES})
+    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
   endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
@@ -175,6 +176,14 @@ target_include_directories(
   $<INSTALL_INTERFACE:include>
 )

+FetchContent_Declare(fmt
+  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+  GIT_TAG 10.2.1
+  EXCLUDE_FROM_ALL
+)
+FetchContent_MakeAvailable(fmt)
+target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
+
 if (MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
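The fmt dependency pulled in above is used by the Metal JIT path to splice type and op names into kernel source templates (see the jit/*.h headers added later in this commit). A minimal, standalone illustration of the positional-placeholder style those templates rely on; the template string and values here are made up for the example:

// Minimal sketch of fmt's positional placeholders as used by the JIT source
// templates in this commit. The template string below is illustrative only.
#include <fmt/format.h>
#include <iostream>
#include <string>

int main() {
  std::string instantiation = fmt::format(
      "template [[host_name(\"ss{0}\")]] [[kernel]] void binary_ss<{1}, {2}, {3}>(...);",
      "addfloat32", // {0}: kernel name suffix
      "float",      // {1}: input type
      "float",      // {2}: output type
      "Add");       // {3}: op functor
  std::cout << instantiation << std::endl;
}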
@@ -163,6 +163,8 @@ should point to the path to the built metal library.
      - ON
    * - MLX_BUILD_GGUF
      - ON
+   * - MLX_METAL_JIT
+     - OFF

 .. note::

@@ -196,9 +198,18 @@ GGUF, you can do:
     cmake ..
       -DCMAKE_BUILD_TYPE=MinSizeRel \
       -DBUILD_SHARED_LIBS=ON \
-      -DMLX_BUILD_CPU=ON \
+      -DMLX_BUILD_CPU=OFF \
       -DMLX_BUILD_SAFETENSORS=OFF \
-      -DMLX_BUILD_GGUF=OFF
+      -DMLX_BUILD_GGUF=OFF \
+      -DMLX_METAL_JIT=ON
+
+The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library, which
+contains pre-built GPU kernels. It substantially reduces the size of the Metal
+library by compiling kernels at run time, the first time they are used in MLX
+on a given machine. Note that run-time compilation incurs a cold-start cost,
+which can be anywhere from a few hundred milliseconds to a few seconds
+depending on the application. Once a kernel is compiled, it is cached by the
+system, and the Metal kernel cache persists across reboots.

 Troubleshooting
 ^^^^^^^^^^^^^^^
@@ -1,6 +1,8 @@
 // Copyright © 2023 Apple Inc.

 #pragma once
+#include <cassert>
+
 #include "mlx/allocator.h"
 #include "mlx/array.h"
 #include "mlx/backend/common/utils.h"
@@ -98,12 +98,4 @@ void Cholesky::eval(const std::vector<array>& inputs, array& output) {
   cholesky_impl(inputs[0], output, upper_);
 }

-std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
-    const std::vector<array>& inputs,
-    const std::vector<int>& axes) {
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::cholesky(a, upper_, stream())}, {ax}};
-}
-
 } // namespace mlx::core
@@ -1,33 +1,80 @@
-add_custom_command(
-  OUTPUT compiled_preamble.cpp
-  COMMAND /bin/bash
-    ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
-    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
-    ${CMAKE_C_COMPILER}
-    ${PROJECT_SOURCE_DIR}
-    "-D${MLX_METAL_VERSION}"
-  DEPENDS make_compiled_preamble.sh
-    kernels/compiled_preamble.h
-    kernels/unary.h
-    kernels/binary.h
-    kernels/bf16.h
-    kernels/erf.h
-    kernels/expm1f.h
-    kernels/utils.h
-    kernels/bf16_math.h
-)
+function(make_jit_source SRC_NAME)
+  # This function takes a metal header file,
+  # runs the C preprocessesor on it, and makes
+  # the processed contents available as a string in a C++ function
+  # mlx::core::metal::${SRC_NAME}()
+  #
+  # To use the function, declare it in jit/includes.h and
+  # include jit/includes.h.
+  #
+  # Additional arguments to this function are treated as dependencies
+  # in the Cmake build system.
+  add_custom_command(
+    OUTPUT jit/${SRC_NAME}.cpp
+    COMMAND /bin/bash
+      ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
+      ${CMAKE_CURRENT_BINARY_DIR}/jit
+      ${CMAKE_C_COMPILER}
+      ${PROJECT_SOURCE_DIR}
+      ${SRC_NAME}
+      "-D${MLX_METAL_VERSION}"
+    DEPENDS make_compiled_preamble.sh
+      kernels/${SRC_NAME}.h
+      ${ARGN}
+  )
+  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
+  add_dependencies(mlx ${SRC_NAME})
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
+  )
+endfunction(make_jit_source)

-add_custom_target(
-  compiled_preamble
-  DEPENDS compiled_preamble.cpp
-)
-
-add_dependencies(mlx compiled_preamble)
+make_jit_source(
+  utils
+  kernels/bf16.h
+  kernels/complex.h
+)
+make_jit_source(
+  unary_ops
+  kernels/erf.h
+  kernels/expm1f.h
+)
+make_jit_source(binary_ops)
+make_jit_source(ternary_ops)
+make_jit_source(
+  reduction
+  kernels/atomic.h
+  kernels/reduction/ops.h
+)
+make_jit_source(scatter)
+make_jit_source(gather)
+
+if (MLX_METAL_JIT)
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
+  )
+  make_jit_source(copy)
+  make_jit_source(unary)
+  make_jit_source(binary)
+  make_jit_source(binary_two)
+  make_jit_source(ternary)
+else()
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
+  )
+endif()

 target_sources(
   mlx
   PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
@@ -46,7 +93,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
 )

 if (NOT MLX_METAL_PATH)
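For orientation, the comment block in make_jit_source says the preprocessed header ends up as a string returned by mlx::core::metal::<name>(). A hypothetical sketch of what a generated jit/utils.cpp stub could look like; the real output comes from make_compiled_preamble.sh, which is not shown on this page:

// Hypothetical shape of a generated jit/utils.cpp: the preprocessed Metal
// header is embedded as a raw string and returned by a function declared in
// mlx/backend/metal/jit/includes.h.
namespace mlx::core::metal {

const char* utils() {
  return R"preamble(
// ... preprocessed contents of kernels/utils.h, kernels/bf16.h, kernels/complex.h ...
)preamble";
}

} // namespace mlx::core::metal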
mlx/backend/metal/binary.cpp (new file, 322 lines)
@@ -0,0 +1,322 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;

void binary_op(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::string op) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt, true);
  set_binary_op_output_data(a, b, outputs[1], bopt, true);

  auto& out = outputs[0];
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
  auto& strides_a = strides[0];
  auto& strides_b = strides[1];
  auto& strides_out = strides[2];

  std::string kernel_name;
  {
    std::ostringstream kname;
    switch (bopt) {
      case BinaryOpType::ScalarScalar:
        kname << "ss";
        break;
      case BinaryOpType::ScalarVector:
        kname << "sv";
        break;
      case BinaryOpType::VectorScalar:
        kname << "vs";
        break;
      case BinaryOpType::VectorVector:
        kname << "vv";
        break;
      case BinaryOpType::General:
        kname << "g";
        if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
          kname << shape.size();
        } else {
          kname << "n";
        }
        break;
    }
    kname << op << type_to_name(a);
    kernel_name = kname.str();
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);

  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);

  // - If a is donated it goes to the first output
  // - If b is donated it goes to the first output if a was not donated
  //   otherwise it goes to the second output
  bool donate_a = a.data_shared_ptr() == nullptr;
  bool donate_b = b.data_shared_ptr() == nullptr;
  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
  compute_encoder.set_input_array(
      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
  compute_encoder.set_output_array(outputs[0], 2);
  compute_encoder.set_output_array(outputs[1], 3);

  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 7);
    }

    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads = out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}

void binary_op(
    const std::vector<array>& inputs,
    array& out,
    const std::string op) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt, true);
  if (out.size() == 0) {
    return;
  }

  // Try to collapse contiguous dims
  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
  auto& strides_a = strides[0];
  auto& strides_b = strides[1];
  auto& strides_out = strides[2];

  std::string kernel_name;
  {
    std::ostringstream kname;
    switch (bopt) {
      case BinaryOpType::ScalarScalar:
        kname << "ss";
        break;
      case BinaryOpType::ScalarVector:
        kname << "sv";
        break;
      case BinaryOpType::VectorScalar:
        kname << "vs";
        break;
      case BinaryOpType::VectorVector:
        kname << "vv";
        break;
      case BinaryOpType::General:
        kname << "g";
        if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
          kname << shape.size();
        } else {
          kname << "n";
        }
        break;
    }
    kname << op << type_to_name(a);
    kernel_name = kname.str();
  }

  auto& s = out.primitive().stream();
  auto& d = metal::device(s.device);

  auto kernel = get_binary_kernel(d, kernel_name, a, out);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder->setComputePipelineState(kernel);
  bool donate_a = a.data_shared_ptr() == nullptr;
  bool donate_b = b.data_shared_ptr() == nullptr;
  compute_encoder.set_input_array(donate_a ? out : a, 0);
  compute_encoder.set_input_array(donate_b ? out : b, 1);
  compute_encoder.set_output_array(out, 2);

  if (bopt == BinaryOpType::General) {
    auto ndim = shape.size();
    if (ndim > 3) {
      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
    } else {
      // The shape is implicit in the grid for <= 3D
      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
    }

    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
      compute_encoder->setBytes(&ndim, sizeof(int), 6);
    }

    // Launch up to 3D grid of threads
    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
    size_t rest = out.size() / (dim0 * dim1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size != 1024) {
      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
    }
    auto group_dims = get_block_dims(dim0, dim1, rest);
    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  } else {
    // Launch a 1D grid of threads
    size_t nthreads =
        bopt == BinaryOpType::General ? out.size() : out.data_size();
    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
    if (thread_group_size > nthreads) {
      thread_group_size = nthreads;
    }
    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }
}

void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "add");
}

void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "arctan2");
}

void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
  switch (op_) {
    case BitwiseBinary::And:
      binary_op(inputs, out, "bitwise_and");
      break;
    case BitwiseBinary::Or:
      binary_op(inputs, out, "bitwise_or");
      break;
    case BitwiseBinary::Xor:
      binary_op(inputs, out, "bitwise_xor");
      break;
    case BitwiseBinary::LeftShift:
      binary_op(inputs, out, "left_shift");
      break;
    case BitwiseBinary::RightShift:
      binary_op(inputs, out, "right_shift");
      break;
  }
}

void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "div");
}

void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  binary_op(inputs, outputs, "divmod");
}

void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "rem");
}

void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}

void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "ge");
}

void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "geq");
}

void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "le");
}

void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "leq");
}

void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "land");
}

void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lor");
}

void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "lae");
}

void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "max");
}

void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "min");
}

void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "mul");
}

void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "neq");
}

void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "pow");
}

void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "sub");
}

} // namespace mlx::core
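As a quick reference for the dispatch above, the JIT kernel name is just the op-type prefix, the op string, and the input dtype suffix concatenated. A standalone sketch; the "float32" suffix is an assumed stand-in for whatever type_to_name() returns:

// Standalone sketch of how the kernel name is assembled in binary_op() above.
#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  kname << "vv";       // BinaryOpType::VectorVector
  kname << "add";      // op string passed by Add::eval_gpu
  kname << "float32";  // stand-in for type_to_name(a)
  std::cout << kname.str() << "\n"; // "vvaddfloat32"
}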
@@ -4,8 +4,8 @@

 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/utils.h"
-#include "mlx/backend/metal/compiled_preamble.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/jit/includes.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/graph_utils.h"
 #include "mlx/primitives.h"
@@ -190,7 +190,8 @@ void Compiled::eval_gpu(
   // If not we have to build it ourselves
   if (lib == nullptr) {
     std::ostringstream kernel;
-    kernel << metal::get_kernel_preamble() << std::endl;
+    kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
+           << metal::ternary_ops();
     build_kernel(
         kernel,
         kernel_lib_ + "_contiguous",
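The change above replaces the single precompiled preamble with one assembled at run time from the embedded header sources. A condensed sketch of that assembly, assuming the accessor functions declared in jit/includes.h (added later in this commit):

// Minimal sketch: build the JIT compilation preamble by concatenating the
// embedded header sources; build_kernel() and the rest of Compiled::eval_gpu
// are unchanged and omitted here.
#include <sstream>
#include <string>

#include "mlx/backend/metal/jit/includes.h"

std::string make_preamble() {
  std::ostringstream kernel;
  kernel << mlx::core::metal::utils()       // common helpers, bf16, complex
         << mlx::core::metal::unary_ops()   // unary op functors
         << mlx::core::metal::binary_ops()  // binary op functors
         << mlx::core::metal::ternary_ops();
  return kernel.str();
}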
@@ -1,9 +0,0 @@
-// Copyright © 2023-24 Apple Inc.
-
-#pragma once
-
-namespace mlx::core::metal {
-
-const char* get_kernel_preamble();
-
-}
@@ -4,12 +4,14 @@

 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"

 namespace mlx::core {

+constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
@@ -62,27 +64,34 @@ void copy_gpu_inplace(
   auto& strides_out_ = strides[1];

   auto& d = metal::device(s.device);
+  std::string kernel_name;
+  {
     std::ostringstream kname;
     switch (ctype) {
       case CopyType::Scalar:
-        kname << "scopy";
+        kname << "s";
         break;
       case CopyType::Vector:
-        kname << "vcopy";
+        kname << "v";
         break;
       case CopyType::General:
-        kname << "gcopy";
+        kname << "g";
        break;
       case CopyType::GeneralGeneral:
-        kname << "ggcopy";
+        kname << "gg";
         break;
     }
-    kname << type_to_name(in) << type_to_name(out);
     if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
         shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
-      kname << "_" << shape.size();
+      kname << shape.size();
     }
-  auto kernel = d.get_kernel(kname.str());
+    kname << "_copy";
+    kname << type_to_name(in) << type_to_name(out);
+    kernel_name = kname.str();
+  }
+
+  auto kernel = get_copy_kernel(d, kernel_name, in, out);
+
   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_in = in.data_shared_ptr() == nullptr;
@@ -106,7 +115,7 @@ void copy_gpu_inplace(
     set_vector_bytes(compute_encoder, strides_out, ndim, 4);
   }

-  if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
+  if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
     compute_encoder->setBytes(&ndim, sizeof(int), 5);
   }

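With this renaming, a copy kernel name is the copy-type prefix, an optional dimension specialization, the literal "_copy", and the input/output dtype suffixes. A standalone sketch; the dtype suffixes are assumed examples of type_to_name() output:

// Sketch of the copy kernel naming scheme after this change; "float32" and
// "float16" stand in for whatever type_to_name() returns for the arrays.
#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  kname << "gg";      // CopyType::GeneralGeneral
  kname << 4;         // <= MAX_COPY_SPECIALIZED_DIMS, so specialize on ndim
  kname << "_copy";
  kname << "float32" << "float16";
  std::cout << kname.str() << "\n"; // "gg4_copyfloat32float16"
}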
@@ -285,7 +285,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
   NS::Error* error = nullptr;
   auto options = MTL::CompileOptions::alloc()->init();
   options->setFastMathEnabled(false);
-
   options->setLanguageVersion(get_metal_version());
   auto mtl_lib = device_->newLibrary(ns_code, options, &error);
   options->release();
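For context, get_library_() is the run-time compilation step the JIT kernels rely on: Metal builds a library directly from source text. A reduced metal-cpp sketch of that flow (the function name is made up, error handling is trimmed to a null check, and setLanguageVersion is omitted because get_metal_version() is internal to MLX):

// Reduced sketch of run-time Metal library compilation with metal-cpp.
#include <Foundation/Foundation.hpp>
#include <Metal/Metal.hpp>
#include <string>

MTL::Library* compile_source(MTL::Device* device, const std::string& source) {
  NS::Error* error = nullptr;
  auto ns_code = NS::String::string(source.c_str(), NS::UTF8StringEncoding);
  auto options = MTL::CompileOptions::alloc()->init();
  options->setFastMathEnabled(false);
  auto lib = device->newLibrary(ns_code, options, &error);
  options->release();
  return lib; // nullptr on failure; `error` carries the compiler diagnostics
}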
@@ -1,24 +1,35 @@
 // Copyright © 2023-2024 Apple Inc.
-#include <algorithm>
-#include <cassert>
-#include <numeric>
-#include <sstream>
+#include <fmt/format.h>

-#include "mlx/backend/common/binary.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/jit/includes.h"
+#include "mlx/backend/metal/jit/indexing.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

 namespace mlx::core {

-namespace {
-
-constexpr int METAL_MAX_INDEX_ARRAYS = 10;
-
-} // namespace
+constexpr int METAL_MAX_INDEX_ARRAYS = 20;
+
+std::pair<std::string, std::string> make_index_args(
+    const std::string& idx_type,
+    int nidx) {
+  std::ostringstream idx_args;
+  std::ostringstream idx_arr;
+  for (int i = 0; i < nidx; ++i) {
+    idx_args << fmt::format(
+        "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
+    idx_arr << fmt::format("idx{0}", i);
+    if (i < nidx - 1) {
+      idx_args << "\n";
+      idx_arr << ",";
+    }
+  }
+  return {idx_args.str(), idx_arr.str()};
+}

 void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
   size_t ndim = src.ndim();

-  std::ostringstream kname;
+  std::string lib_name;
+  std::string kernel_name;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
-  if (idx_ndim <= 1) {
-    kname << "_" << idx_ndim;
+  {
+    std::ostringstream kname;
+    kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
+          << "_" << idx_ndim;
+    lib_name = kname.str();
+    kernel_name = lib_name;
+  }
+
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gather();
+    std::string out_type_str = get_type_string(out.dtype());
+    std::string idx_type_str =
+        nidx ? get_type_string(inputs[1].dtype()) : "bool";
+    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
+
+    // Index dimension specializations
+    kernel_source << fmt::format(
+        gather_kernels,
+        type_to_name(out) + idx_type_name,
+        out_type_str,
+        idx_type_str,
+        nidx,
+        idx_args,
+        idx_arr,
+        idx_ndim);
+    lib = d.get_library(lib_name, kernel_source.str());
   }

   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kernel_name, lib);
   compute_encoder->setComputePipelineState(kernel);

   size_t slice_size = 1;
@@ -102,8 +139,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);

   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }

   // Launch grid
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);

-  // Get kernel name
-  std::ostringstream kname;
-  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
   bool index_nd1_specialization = (idx_ndim == 1);

@@ -159,32 +192,85 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     index_nd1_specialization &= inputs[i].flags().row_contiguous;
   }

+  std::string lib_name;
+  std::string kernel_name;
+  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
+  std::string op_name;
+  switch (reduce_type_) {
+    case Scatter::None:
+      op_name = "none";
+      break;
+    case Scatter::Sum:
+      op_name = "sum";
+      break;
+    case Scatter::Prod:
+      op_name = "prod";
+      break;
+    case Scatter::Max:
+      op_name = "max";
+      break;
+    case Scatter::Min:
+      op_name = "min";
+      break;
+  }
+
+  {
+    std::ostringstream kname;
     if (index_nd1_specialization) {
       kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
     } else {
       kname << "scatter" << type_to_name(out) << idx_type_name;
     }
+    kname << "_" << op_name << "_" << nidx;
+    lib_name = kname.str();
+    kernel_name = kname.str();
+  }
+
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::reduction() << metal::scatter();
+
+    std::string out_type_str = get_type_string(out.dtype());
+    std::string idx_type_str =
+        nidx ? get_type_string(inputs[1].dtype()) : "bool";
+    std::string op_type;
     switch (reduce_type_) {
       case Scatter::None:
-        kname << "_none";
+        op_type = "None";
         break;
       case Scatter::Sum:
-        kname << "_sum";
+        op_type = "Sum<{0}>";
         break;
       case Scatter::Prod:
-        kname << "_prod";
+        op_type = "Prod<{0}>";
         break;
       case Scatter::Max:
-        kname << "_max";
+        op_type = "Max<{0}>";
         break;
       case Scatter::Min:
-        kname << "_min";
+        op_type = "Min<{0}>";
         break;
     }
-    kname << "_" << nidx;
+    if (reduce_type_ != Scatter::None) {
+      op_type = fmt::format(op_type, out_type_str);
+    }
+    auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
+
+    kernel_source << fmt::format(
+        scatter_kernels,
+        type_to_name(out) + idx_type_name + "_" + op_name,
+        out_type_str,
+        idx_type_str,
+        op_type,
+        nidx,
+        idx_args,
+        idx_arr);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }

   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kernel_name, lib);

   auto& upd = inputs.back();
   size_t nthreads = upd.size();
@@ -209,8 +295,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);

   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }

   // Launch grid
@@ -279,8 +365,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }

   // Launch grid
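To make the string splicing concrete, here is a standalone sketch that mirrors the make_index_args helper added above and prints what gets substituted into the gather/scatter templates ({4}/{5} in gather_kernels, {5}/{6} in scatter_kernels). The uint32_t index type and nidx = 2 are example choices:

// Standalone sketch mirroring make_index_args() from the diff above.
#include <fmt/format.h>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>

std::pair<std::string, std::string> make_index_args(
    const std::string& idx_type,
    int nidx) {
  std::ostringstream idx_args;
  std::ostringstream idx_arr;
  for (int i = 0; i < nidx; ++i) {
    idx_args << fmt::format(
        "const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
    idx_arr << fmt::format("idx{0}", i);
    if (i < nidx - 1) {
      idx_args << "\n";
      idx_arr << ",";
    }
  }
  return {idx_args.str(), idx_arr.str()};
}

int main() {
  auto [idx_args, idx_arr] = make_index_args("uint32_t", 2);
  std::cout << idx_args << "\n" << idx_arr << "\n";
  // const device uint32_t *idx0 [[buffer(20)]],
  // const device uint32_t *idx1 [[buffer(21)]],
  // idx0,idx1
}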
|
87
mlx/backend/metal/jit/binary.h
Normal file
87
mlx/backend/metal/jit/binary.h
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
constexpr std::string_view binary_kernels = R"(
|
||||||
|
template [[host_name("ss{0}")]] [[kernel]]
|
||||||
|
void binary_ss<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
template [[host_name("vs{0}")]] [[kernel]]
|
||||||
|
void binary_vs<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
template [[host_name("sv{0}")]] [[kernel]]
|
||||||
|
void binary_sv<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
template [[host_name("vv{0}")]] [[kernel]]
|
||||||
|
void binary_vv<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
template [[host_name("g4{0}")]] [[kernel]] void
|
||||||
|
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
constant const int shape[4],
|
||||||
|
constant const size_t a_strides[4],
|
||||||
|
constant const size_t b_strides[4],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
template [[host_name("g5{0}")]] [[kernel]] void
|
||||||
|
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
constant const int shape[5],
|
||||||
|
constant const size_t a_strides[5],
|
||||||
|
constant const size_t b_strides[5],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
template [[host_name("g1{0}")]] [[kernel]] void
|
||||||
|
binary_g_nd1<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
constant const size_t& a_stride,
|
||||||
|
constant const size_t& b_stride,
|
||||||
|
uint index [[thread_position_in_grid]]);
|
||||||
|
template [[host_name("g2{0}")]] [[kernel]] void
|
||||||
|
binary_g_nd2<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
constant const size_t a_strides[2],
|
||||||
|
constant const size_t b_strides[2],
|
||||||
|
uint2 index [[thread_position_in_grid]],
|
||||||
|
uint2 grid_dim [[threads_per_grid]]);
|
||||||
|
template [[host_name("g3{0}")]] [[kernel]] void
|
||||||
|
binary_g_nd3<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
constant const size_t a_strides[3],
|
||||||
|
constant const size_t b_strides[3],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
|
||||||
|
template [[host_name("gn{0}")]] [[kernel]]
|
||||||
|
void binary_g<{1}, {2}, {3}>(
|
||||||
|
device const {1}* a,
|
||||||
|
device const {1}* b,
|
||||||
|
device {2}* c,
|
||||||
|
constant const int* shape,
|
||||||
|
constant const size_t* a_strides,
|
||||||
|
constant const size_t* b_strides,
|
||||||
|
constant const int& ndim,
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]);
|
||||||
|
)";
|
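The {0}-{3} placeholders above are filled by the JIT kernel getters in jit_kernels.cpp, which this page does not show. A hypothetical instantiation for a float add kernel; the "addfloat32" suffix mirrors the name built in binary.cpp, and "Add" as the functor name is an assumption:

// Hypothetical instantiation of the binary template string for float add.
#include <string>
#include <string_view>

#include <fmt/format.h>

#include "mlx/backend/metal/jit/binary.h" // defines binary_kernels

std::string instantiate_binary_add() {
  return fmt::format(
      binary_kernels,
      "addfloat32", // {0}: kernel name suffix
      "float",      // {1}: input type
      "float",      // {2}: output type
      "Add");       // {3}: op functor (assumed name)
}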
mlx/backend/metal/jit/binary_two.h (new file, 98 lines)
@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    uint index [[thread_position_in_grid]]);

template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    device {2}* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
mlx/backend/metal/jit/copy.h (new file, 100 lines)
@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);

template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]);
)";
mlx/backend/metal/jit/includes.h (new file, 21 lines)
@@ -0,0 +1,21 @@
// Copyright © 2023-24 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduction();
const char* gather();
const char* scatter();

const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();

} // namespace mlx::core::metal
mlx/backend/metal/jit/indexing.h (new file, 81 lines)
@@ -0,0 +1,81 @@
// Copyright © 2023-2024 Apple Inc.

constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
    const device {1}* src [[buffer(0)]],
    device {1}* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const constant int* idx_shapes [[buffer(7)]],
    const constant size_t* idx_strides [[buffer(8)]],
    const constant int& idx_ndim [[buffer(9)]],
    {4}
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {{
  Indices<{2}, {3}> idxs{{
      {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};

  return gather_impl<{1}, {2}, {3}, {6}>(
      src,
      out,
      src_shape,
      src_strides,
      src_ndim,
      slice_sizes,
      axes,
      idxs,
      index,
      grid_dim);
}}
)";

constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
      updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}

[[kernel]] void scatter{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const constant int* idx_shapes [[buffer(11)]],
    const constant size_t* idx_strides [[buffer(12)]],
    const constant int& idx_ndim [[buffer(13)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};

  return scatter_impl<{1}, {2}, {3}, {4}>(
      updates,
      out,
      upd_shape,
      upd_strides,
      upd_ndim,
      upd_size,
      out_shape,
      out_strides,
      out_ndim,
      axes,
      idxs,
      gid);
}}
)";
mlx/backend/metal/jit/ternary.h (new file, 80 lines)
@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    uint index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    constant const size_t c_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    constant const size_t c_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
16 mlx/backend/metal/jit/unary.h Normal file
@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    uint index [[thread_position_in_grid]]);

template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]);
)";
124 mlx/backend/metal/jit_kernels.cpp Normal file
@ -0,0 +1,124 @@
// Copyright © 2024 Apple Inc.

#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"

namespace mlx::core {

std::string op_name(const array& arr) {
  std::ostringstream op_t;
  arr.primitive().print(op_t);
  return op_t.str();
}

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  std::string lib_name = kernel_name.substr(1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
                  << fmt::format(
                         unary_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(2);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
                  << fmt::format(
                         binary_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(2);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::binary_ops()
                  << metal::binary_two()
                  << fmt::format(
                         binary_two_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
                  << fmt::format(
                         ternary_kernels,
                         lib_name,
                         get_type_string(out.dtype()),
                         op_name(out));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::copy()
                  << fmt::format(
                         copy_kernels,
                         lib_name,
                         get_type_string(in.dtype()),
                         get_type_string(out.dtype()));
    lib = d.get_library(lib_name, kernel_source.str());
  }
  return d.get_kernel(kernel_name, lib);
}

} // namespace mlx::core
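Each getter above follows the same compile-on-miss pattern: look up the library by name, generate and compile the Metal source only when it is missing, then fetch the pipeline state from that library. A hedged sketch of that shared shape, written as a hypothetical helper (not part of this commit) that reuses only the Device calls already shown above:

// Hypothetical consolidation of the pattern above; not part of this commit.
#include <string>
#include "mlx/backend/metal/device.h"

namespace mlx::core {
MTL::ComputePipelineState* get_jit_kernel(
    metal::Device& d,
    const std::string& lib_name,
    const std::string& kernel_name,
    const std::string& source) {
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    // Compile the generated source once; later lookups hit the cached library.
    lib = d.get_library(lib_name, source);
  }
  return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core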
36 mlx/backend/metal/kernels.h Normal file
@ -0,0 +1,36 @@
// Copyright © 2024 Apple Inc.

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

namespace mlx::core {

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out);

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out);

} // namespace mlx::core
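These declarations pair with the definitions in jit_kernels.cpp above, where the library name is derived from the kernel name: one leading character is stripped for unary, two for binary and binary_two, and everything up to the first underscore for ternary and copy. A small standalone illustration of that string handling, with hypothetical kernel names (only the stripping logic mirrors jit_kernels.cpp):

// The kernel names below are made up for illustration.
#include <iostream>
#include <string>

int main() {
  std::string unary_name = "vabs_float32";
  std::string copy_name = "s_copy_float32";
  std::cout << unary_name.substr(1) << "\n";                       // abs_float32
  std::cout << copy_name.substr(copy_name.find("_") + 1) << "\n";  // copy_float32
}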
@ -3,13 +3,8 @@ set(
   ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
   ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
   ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/binary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
   ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/unary.h
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
 )

@ -17,10 +12,7 @@ set(
   KERNELS
   "arange"
   "arg_reduce"
-  "binary"
-  "binary_two"
   "conv"
-  "copy"
   "fft"
   "gemv"
   "quantized"
@ -32,12 +24,30 @@ set(
   "scaled_dot_product_attention"
   "softmax"
   "sort"
-  "ternary"
-  "unary"
-  "gather"
-  "scatter"
 )

+if (NOT MLX_METAL_JIT)
+  set(
+    KERNELS
+    ${KERNELS}
+    "binary"
+    "binary_two"
+    "unary"
+    "ternary"
+    "copy"
+  )
+  set(
+    HEADERS
+    ${HEADERS}
+    unary_ops.h
+    unary.h
+    binary_ops.h
+    binary.h
+    ternary.h
+    copy.h
+  )
+endif()
+
 function(build_kernel_base TARGET SRCFILE DEPS)
   set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
   if(MLX_METAL_DEBUG)
@ -2,9 +2,11 @@

 #pragma once

+#ifndef MLX_METAL_JIT
 #include <metal_atomic>
 #include <metal_stdlib>
 #include "mlx/backend/metal/kernels/bf16.h"
+#endif

 using namespace metal;

@ -1,273 +1,113 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#pragma once
-
-#include <metal_integer>
-#include <metal_math>
-
-#include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/utils.h"
-
-struct Add {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x + y;
-  }
-};
-
-struct Divide {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x / y;
-  }
-};
-
-struct Remainder {
-  template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
-  operator()(T x, T y) {
-    return x % y;
-  }
-  template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
-  operator()(T x, T y) {
-    auto r = x % y;
-    if (r != 0 && (r < 0 != y < 0)) {
-      r += y;
-    }
-    return r;
-  }
-  template <typename T>
-  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
-    T r = fmod(x, y);
-    if (r != 0 && (r < 0 != y < 0)) {
-      r += y;
-    }
-    return r;
-  }
-  template <>
-  complex64_t operator()(complex64_t x, complex64_t y) {
-    return x % y;
-  }
-};
-
-struct Equal {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x == y;
-  }
-};
-
-struct NaNEqual {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x == y || (metal::isnan(x) && metal::isnan(y));
-  }
-  template <>
-  bool operator()(complex64_t x, complex64_t y) {
-    return x == y ||
-        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
-         metal::isnan(y.imag)) ||
-        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
-        (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
-  }
-};
-
-struct Greater {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x > y;
-  }
-};
-
-struct GreaterEqual {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x >= y;
-  }
-};
-
-struct Less {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x < y;
-  }
-};
-
-struct LessEqual {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x <= y;
-  }
-};
-
-struct LogAddExp {
-  template <typename T>
-  T operator()(T x, T y) {
-    if (metal::isnan(x) || metal::isnan(y)) {
-      return metal::numeric_limits<T>::quiet_NaN();
-    }
-    constexpr T inf = metal::numeric_limits<T>::infinity();
-    T maxval = metal::max(x, y);
-    T minval = metal::min(x, y);
-    return (minval == -inf || maxval == inf)
-        ? maxval
-        : (maxval + log1p(metal::exp(minval - maxval)));
-  };
-};
-
-struct Maximum {
-  template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
-    return metal::max(x, y);
-  }
-
-  template <typename T>
-  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
-    if (metal::isnan(x)) {
-      return x;
-    }
-    return x > y ? x : y;
-  }
-
-  template <>
-  complex64_t operator()(complex64_t x, complex64_t y) {
-    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
-      return x;
-    }
-    return x > y ? x : y;
-  }
-};
-
-struct Minimum {
-  template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
-    return metal::min(x, y);
-  }
-
-  template <typename T>
-  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
-    if (metal::isnan(x)) {
-      return x;
-    }
-    return x < y ? x : y;
-  }
-
-  template <>
-  complex64_t operator()(complex64_t x, complex64_t y) {
-    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
-      return x;
-    }
-    return x < y ? x : y;
-  }
-};
-
-struct Multiply {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x * y;
-  }
-};
-
-struct NotEqual {
-  template <typename T>
-  bool operator()(T x, T y) {
-    return x != y;
-  }
-  template <>
-  bool operator()(complex64_t x, complex64_t y) {
-    return x.real != y.real || x.imag != y.imag;
-  }
-};
-
-struct Power {
-  template <typename T>
-  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
-    return metal::pow(base, exp);
-  }
-
-  template <typename T>
-  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
-    T res = 1;
-    while (exp) {
-      if (exp & 1) {
-        res *= base;
-      }
-      exp >>= 1;
-      base *= base;
-    }
-    return res;
-  }
-
-  template <>
-  complex64_t operator()(complex64_t x, complex64_t y) {
-    auto x_theta = metal::atan(x.imag / x.real);
-    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
-    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
-    auto phase = y.imag * x_ln_r + y.real * x_theta;
-    return {mag * metal::cos(phase), mag * metal::sin(phase)};
-  }
-};
-
-struct Subtract {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x - y;
-  }
-};
-
-struct LogicalAnd {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x && y;
-  };
-};
-
-struct LogicalOr {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x || y;
-  };
-};
-
-struct BitwiseAnd {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x & y;
-  };
-};
-
-struct BitwiseOr {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x | y;
-  };
-};
-
-struct BitwiseXor {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x ^ y;
-  };
-};
-
-struct LeftShift {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x << y;
-  };
-};
-
-struct RightShift {
-  template <typename T>
-  T operator()(T x, T y) {
-    return x >> y;
-  };
-};
-
-struct ArcTan2 {
-  template <typename T>
-  T operator()(T y, T x) {
-    return metal::precise::atan2(y, x);
-  }
-};
+// Copyright © 2024 Apple Inc.
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_ss(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op()(a[0], b[0]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_sv(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op()(a[0], b[index]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_vs(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op()(a[index], b[0]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_vv(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    uint index [[thread_position_in_grid]]) {
+  c[index] = Op()(a[index], b[index]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_g_nd1(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    constant const size_t& a_stride,
+    constant const size_t& b_stride,
+    uint index [[thread_position_in_grid]]) {
+  auto a_idx = elem_to_loc_1(index, a_stride);
+  auto b_idx = elem_to_loc_1(index, b_stride);
+  c[index] = Op()(a[a_idx], b[b_idx]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_g_nd2(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    constant const size_t a_strides[2],
+    constant const size_t b_strides[2],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_2(index, a_strides);
+  auto b_idx = elem_to_loc_2(index, b_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  c[out_idx] = Op()(a[a_idx], b[b_idx]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_g_nd3(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    constant const size_t a_strides[3],
+    constant const size_t b_strides[3],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_3(index, a_strides);
+  auto b_idx = elem_to_loc_3(index, b_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  c[out_idx] = Op()(a[a_idx], b[b_idx]);
+}
+
+template <typename T, typename U, typename Op, int DIM>
+[[kernel]] void binary_g_nd(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    constant const int shape[DIM],
+    constant const size_t a_strides[DIM],
+    constant const size_t b_strides[DIM],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  c[out_idx] = Op()(a[idx.x], b[idx.y]);
+}
+
+template <typename T, typename U, typename Op>
+[[kernel]] void binary_g(
+    device const T* a,
+    device const T* b,
+    device U* c,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const int& ndim,
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
+  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  c[out_idx] = Op()(a[idx.x], b[idx.y]);
+}
@ -1,130 +1,24 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.

+#include <metal_integer>
+#include <metal_math>
+
+// clang-format off
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/binary_ops.h"
 #include "mlx/backend/metal/kernels/binary.h"

-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_ss(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[0], b[0]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_sv(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[0], b[index]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_vs(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[0]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_vv(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op()(a[index], b[index]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_g_nd1(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const size_t& a_stride,
-    constant const size_t& b_stride,
-    uint index [[thread_position_in_grid]]) {
-  auto a_idx = elem_to_loc_1(index, a_stride);
-  auto b_idx = elem_to_loc_1(index, b_stride);
-  c[index] = Op()(a[a_idx], b[b_idx]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_g_nd2(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const size_t a_strides[2],
-    constant const size_t b_strides[2],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_2(index, a_strides);
-  auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
-  c[out_idx] = Op()(a[a_idx], b[b_idx]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_g_nd3(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const size_t a_strides[3],
-    constant const size_t b_strides[3],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_3(index, a_strides);
-  auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op()(a[a_idx], b[b_idx]);
-}
-
-template <typename T, typename U, typename Op, int DIM>
-[[kernel]] void binary_op_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
-}
-
-template <typename T, typename U, typename Op>
-[[kernel]] void binary_op_g(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    constant const int* shape,
-    constant const size_t* a_strides,
-    constant const size_t* b_strides,
-    constant const int& ndim,
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
-  c[out_idx] = Op()(a[idx.x], b[idx.y]);
-}
-
 #define instantiate_binary(name, itype, otype, op, bopt) \
   template \
-  [[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
+  [[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
       device const itype* a, \
       device const itype* b, \
       device otype* c, \
       uint index [[thread_position_in_grid]]);

 #define instantiate_binary_g_dim(name, itype, otype, op, dims) \
-  template [[host_name(name "_" #dims)]] [[kernel]] void \
-  binary_op_g_nd<itype, otype, op, dims>( \
+  template [[host_name("g" #dims name)]] [[kernel]] void \
+  binary_g_nd<itype, otype, op, dims>( \
       device const itype* a, \
       device const itype* b, \
       device otype* c, \
@ -135,16 +29,16 @@ template <typename T, typename U, typename Op>
       uint3 grid_dim [[threads_per_grid]]);

 #define instantiate_binary_g_nd(name, itype, otype, op) \
-  template [[host_name(name "_1")]] [[kernel]] void \
-  binary_op_g_nd1<itype, otype, op>( \
+  template [[host_name("g1" name)]] [[kernel]] void \
+  binary_g_nd1<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] [[kernel]] void \
-  binary_op_g_nd2<itype, otype, op>( \
+  template [[host_name("g2" name)]] [[kernel]] void \
+  binary_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@ -152,8 +46,8 @@ template <typename T, typename U, typename Op>
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] [[kernel]] void \
-  binary_op_g_nd3<itype, otype, op>( \
+  template [[host_name("g3" name)]] [[kernel]] void \
+  binary_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@ -165,7 +59,7 @@ template <typename T, typename U, typename Op>
   instantiate_binary_g_dim(name, itype, otype, op, 5)

 #define instantiate_binary_g(name, itype, otype, op) \
-  template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
+  template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@ -176,16 +70,14 @@ template <typename T, typename U, typename Op>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

-// clang-format off
 #define instantiate_binary_all(name, tname, itype, otype, op) \
   instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
   instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
   instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
   instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
-  instantiate_binary_g("g" #name #tname, itype, otype, op) \
-  instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
+  instantiate_binary_g(#name #tname, itype, otype, op) \
+  instantiate_binary_g_nd(#name #tname, itype, otype, op)

-// clang-format off
 #define instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
   instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@ -194,22 +86,19 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, int8, int8_t, int8_t, op) \
   instantiate_binary_all(name, int16, int16_t, int16_t, op) \
   instantiate_binary_all(name, int32, int32_t, int32_t, op) \
-  instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
+  instantiate_binary_all(name, int64, int64_t, int64_t, op)

-// clang-format off
 #define instantiate_binary_float(name, op) \
   instantiate_binary_all(name, float16, half, half, op) \
   instantiate_binary_all(name, float32, float, float, op) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)

-// clang-format off
 #define instantiate_binary_types(name, op) \
   instantiate_binary_all(name, bool_, bool, bool, op) \
   instantiate_binary_integer(name, op) \
   instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
-  instantiate_binary_float(name, op) // clang-format on
+  instantiate_binary_float(name, op)

-// clang-format off
 #define instantiate_binary_types_bool(name, op) \
   instantiate_binary_all(name, bool_, bool, bool, op) \
   instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@ -223,9 +112,8 @@ template <typename T, typename U, typename Op>
   instantiate_binary_all(name, float16, half, bool, op) \
   instantiate_binary_all(name, float32, float, bool, op) \
   instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
-  instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
+  instantiate_binary_all(name, complex64, complex64_t, bool, op)

-// clang-format off
 instantiate_binary_types(add, Add)
 instantiate_binary_types(div, Divide)
 instantiate_binary_types_bool(eq, Equal)
296
mlx/backend/metal/kernels/binary_ops.h
Normal file
296
mlx/backend/metal/kernels/binary_ops.h
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <metal_integer>
|
||||||
|
#include <metal_math>
|
||||||
|
|
||||||
|
struct Add {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x + y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FloorDivide {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x / y;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
float operator()(float x, float y) {
|
||||||
|
return trunc(x / y);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
half operator()(half x, half y) {
|
||||||
|
return trunc(x / y);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
||||||
|
return trunc(x / y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Divide {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x / y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Remainder {
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||||
|
operator()(T x, T y) {
|
||||||
|
return x % y;
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||||
|
operator()(T x, T y) {
|
||||||
|
auto r = x % y;
|
||||||
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
|
r += y;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
T r = fmod(x, y);
|
||||||
|
if (r != 0 && (r < 0 != y < 0)) {
|
||||||
|
r += y;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x % y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Equal {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x == y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NaNEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
bool operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x == y ||
|
||||||
|
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
|
||||||
|
metal::isnan(y.imag)) ||
|
||||||
|
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||||
|
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Greater {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x > y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct GreaterEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x >= y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Less {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x < y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LessEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x <= y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogAddExp {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
if (metal::isnan(x) || metal::isnan(y)) {
|
||||||
|
return metal::numeric_limits<T>::quiet_NaN();
|
||||||
|
}
|
||||||
|
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||||
|
T maxval = metal::max(x, y);
|
||||||
|
T minval = metal::min(x, y);
|
||||||
|
return (minval == -inf || maxval == inf)
|
||||||
|
? maxval
|
||||||
|
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Maximum {
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
return metal::max(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
if (metal::isnan(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return x > y ? x : y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return x > y ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Minimum {
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
return metal::min(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||||
|
if (metal::isnan(x)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return x < y ? x : y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
return x < y ? x : y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Multiply {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x * y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NotEqual {
|
||||||
|
template <typename T>
|
||||||
|
bool operator()(T x, T y) {
|
||||||
|
return x != y;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
bool operator()(complex64_t x, complex64_t y) {
|
||||||
|
return x.real != y.real || x.imag != y.imag;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Power {
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||||
|
return metal::pow(base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||||
|
T res = 1;
|
||||||
|
while (exp) {
|
||||||
|
if (exp & 1) {
|
||||||
|
res *= base;
|
||||||
|
}
|
||||||
|
exp >>= 1;
|
||||||
|
base *= base;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||||
|
auto x_theta = metal::atan(x.imag / x.real);
|
||||||
|
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||||
|
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||||
|
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||||
|
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Subtract {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x - y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogicalAnd {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x && y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LogicalOr {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x || y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BitwiseAnd {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x & y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BitwiseOr {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x | y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BitwiseXor {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x ^ y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LeftShift {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x << y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RightShift {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T x, T y) {
|
||||||
|
return x >> y;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ArcTan2 {
|
||||||
|
template <typename T>
|
||||||
|
T operator()(T y, T x) {
|
||||||
|
return metal::precise::atan2(y, x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DivMod {
|
||||||
|
template <typename T>
|
||||||
|
metal::array<T, 2> operator()(T x, T y) {
|
||||||
|
return {FloorDivide{}(x, y), Remainder{}(x, y)};
|
||||||
|
};
|
||||||
|
};
|
140
mlx/backend/metal/kernels/binary_two.h
Normal file
140
mlx/backend/metal/kernels/binary_two.h
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_ss(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto out = Op()(a[0], b[0]);
|
||||||
|
c[index] = out[0];
|
||||||
|
d[index] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_sv(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto out = Op()(a[0], b[index]);
|
||||||
|
c[index] = out[0];
|
||||||
|
d[index] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_vs(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto out = Op()(a[index], b[0]);
|
||||||
|
c[index] = out[0];
|
||||||
|
d[index] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_vv(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto out = Op()(a[index], b[index]);
|
||||||
|
c[index] = out[0];
|
||||||
|
d[index] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_g_nd1(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
constant const size_t& a_stride,
|
||||||
|
constant const size_t& b_stride,
|
||||||
|
uint index [[thread_position_in_grid]]) {
|
||||||
|
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||||
|
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||||
|
auto out = Op()(a[a_idx], b[b_idx]);
|
||||||
|
c[index] = out[0];
|
||||||
|
d[index] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_g_nd2(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
constant const size_t a_strides[2],
|
||||||
|
constant const size_t b_strides[2],
|
||||||
|
uint2 index [[thread_position_in_grid]],
|
||||||
|
uint2 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||||
|
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||||
|
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||||
|
auto out = Op()(a[a_idx], b[b_idx]);
|
||||||
|
c[out_idx] = out[0];
|
||||||
|
d[out_idx] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_g_nd3(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
constant const size_t a_strides[3],
|
||||||
|
constant const size_t b_strides[3],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||||
|
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||||
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
|
auto out = Op()(a[a_idx], b[b_idx]);
|
||||||
|
c[out_idx] = out[0];
|
||||||
|
d[out_idx] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op, int DIM>
|
||||||
|
[[kernel]] void binary_g_nd(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
constant const int shape[DIM],
|
||||||
|
constant const size_t a_strides[DIM],
|
||||||
|
constant const size_t b_strides[DIM],
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||||
|
size_t out_idx =
|
||||||
|
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||||
|
auto out = Op()(a[idx.x], b[idx.y]);
|
||||||
|
c[out_idx] = out[0];
|
||||||
|
d[out_idx] = out[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
[[kernel]] void binary_g(
|
||||||
|
device const T* a,
|
||||||
|
device const T* b,
|
||||||
|
device U* c,
|
||||||
|
device U* d,
|
||||||
|
constant const int* shape,
|
||||||
|
constant const size_t* a_strides,
|
||||||
|
constant const size_t* b_strides,
|
||||||
|
constant const int& ndim,
|
||||||
|
uint3 index [[thread_position_in_grid]],
|
||||||
|
uint3 grid_dim [[threads_per_grid]]) {
|
||||||
|
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||||
|
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||||
|
auto out = Op()(a[idx.x], b[idx.y]);
|
||||||
|
c[out_idx] = out[0];
|
||||||
|
d[out_idx] = out[1];
|
||||||
|
}
|
@ -1,212 +1,24 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <metal_integer>
|
#include <metal_integer>
|
||||||
#include <metal_math>
|
#include <metal_math>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||||
|
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||||
|
|
||||||
struct FloorDivide {
|
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||||
template <typename T>
|
|
||||||
T operator()(T x, T y) {
|
|
||||||
return x / y;
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
float operator()(float x, float y) {
|
|
||||||
return trunc(x / y);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
half operator()(half x, half y) {
|
|
||||||
return trunc(x / y);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
|
||||||
return trunc(x / y);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Remainder {
|
|
||||||
template <typename T>
|
|
||||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
|
||||||
operator()(T x, T y) {
|
|
||||||
return x % y;
|
|
||||||
}
|
|
||||||
template <typename T>
|
|
||||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
|
||||||
operator()(T x, T y) {
|
|
||||||
auto r = x % y;
|
|
||||||
if (r != 0 && (r < 0 != y < 0)) {
|
|
||||||
r += y;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
template <typename T>
|
|
||||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
|
||||||
T r = fmod(x, y);
|
|
||||||
if (r != 0 && (r < 0 != y < 0)) {
|
|
||||||
r += y;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
|
||||||
return x % y;
|
|
||||||
}
|
|
||||||
};
|
|
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_s2s(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op1()(a[0], b[0]);
-  d[index] = Op2()(a[0], b[0]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_ss(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op1()(a[0], b[0]);
-  d[index] = Op2()(a[0], b[0]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_sv(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op1()(a[0], b[index]);
-  d[index] = Op2()(a[0], b[index]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_vs(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op1()(a[index], b[0]);
-  d[index] = Op2()(a[index], b[0]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_vv(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    uint index [[thread_position_in_grid]]) {
-  c[index] = Op1()(a[index], b[index]);
-  d[index] = Op2()(a[index], b[index]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_g_nd1(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const size_t& a_stride,
-    constant const size_t& b_stride,
-    uint index [[thread_position_in_grid]]) {
-  auto a_idx = elem_to_loc_1(index, a_stride);
-  auto b_idx = elem_to_loc_1(index, b_stride);
-  c[index] = Op1()(a[a_idx], b[b_idx]);
-  d[index] = Op2()(a[a_idx], b[b_idx]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_g_nd2(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const size_t a_strides[2],
-    constant const size_t b_strides[2],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_2(index, a_strides);
-  auto b_idx = elem_to_loc_2(index, b_strides);
-  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
-  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
-  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_g_nd3(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const size_t a_strides[3],
-    constant const size_t b_strides[3],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto a_idx = elem_to_loc_3(index, a_strides);
-  auto b_idx = elem_to_loc_3(index, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op1()(a[a_idx], b[b_idx]);
-  d[out_idx] = Op2()(a[a_idx], b[b_idx]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2, int DIM>
-[[kernel]] void binary_op_g_nd(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const int shape[DIM],
-    constant const size_t a_strides[DIM],
-    constant const size_t b_strides[DIM],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
-  size_t out_idx =
-      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
-  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
-  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
-}
-
-template <typename T, typename U, typename Op1, typename Op2>
-[[kernel]] void binary_op_g(
-    device const T* a,
-    device const T* b,
-    device U* c,
-    device U* d,
-    constant const int* shape,
-    constant const size_t* a_strides,
-    constant const size_t* b_strides,
-    constant const int& ndim,
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
-  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
-  c[out_idx] = Op1()(a[idx.x], b[idx.y]);
-  d[out_idx] = Op2()(a[idx.x], b[idx.y]);
-}
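Note: the strided ("g") kernels above lean on the elem_to_loc_* helpers from utils.h, which are not part of this hunk. As a point of reference only, the arithmetic they perform is the usual shape-and-strides walk; the following is a CPU-side C++ sketch of that mapping, not the library's own code.

#include <cstddef>
#include <vector>

// Illustrative reference only: maps a flat element index to a memory offset
// for an array with the given shape and element strides, mirroring what the
// Metal elem_to_loc helpers compute on device.
inline size_t elem_to_loc_ref(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}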
-#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
  template [[host_name(name)]] [[kernel]] void \
-      binary_op_##bopt<itype, otype, op1, op2>( \
+      binary_##bopt<itype, otype, op>( \
          device const itype* a, \
          device const itype* b, \
          device otype* c, \
          device otype* d, \
          uint index [[thread_position_in_grid]]);

-#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
-  template [[host_name(name "_" #dims)]] [[kernel]] void \
-  binary_op_g_nd<itype, otype, op1, op2, dims>( \
+#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
+  template [[host_name("g" #dims name)]] [[kernel]] void \
+  binary_g_nd<itype, otype, op, dims>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -217,10 +29,9 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

-// clang-format off
-#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
-  template [[host_name(name "_1")]] [[kernel]] void \
-  binary_op_g_nd1<itype, otype, op1, op2>( \
+#define instantiate_binary_g_nd(name, itype, otype, op) \
+  template [[host_name("g1" name)]] [[kernel]] void \
+  binary_g_nd1<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -228,8 +39,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t& a_stride, \
      constant const size_t& b_stride, \
      uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] [[kernel]] void \
-  binary_op_g_nd2<itype, otype, op1, op2>( \
+  template [[host_name("g2" name)]] [[kernel]] void \
+  binary_g_nd2<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -238,8 +49,8 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t b_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] [[kernel]] void \
-  binary_op_g_nd3<itype, otype, op1, op2>( \
+  template [[host_name("g3" name)]] [[kernel]] void \
+  binary_g_nd3<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -248,12 +59,12 @@ template <typename T, typename U, typename Op1, typename Op2>
      constant const size_t b_strides[3], \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]); \
-  instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
-  instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
+  instantiate_binary_g_dim(name, itype, otype, op, 4) \
+  instantiate_binary_g_dim(name, itype, otype, op, 5)

-#define instantiate_binary_g(name, itype, otype, op1, op2) \
-  template [[host_name(name)]] [[kernel]] void \
-  binary_op_g<itype, otype, op2, op2>( \
+#define instantiate_binary_g(name, itype, otype, op) \
+  template [[host_name("gn" name)]] [[kernel]] void \
+  binary_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
@@ -265,33 +76,30 @@ template <typename T, typename U, typename Op1, typename Op2>
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

-// clang-format off
-#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
-  instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
-  instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
-  instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
-  instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
-  instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
-  instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
+#define instantiate_binary_all(name, tname, itype, otype, op) \
+  instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
+  instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
+  instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
+  instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
+  instantiate_binary_g(#name #tname, itype, otype, op) \
+  instantiate_binary_g_nd(#name #tname, itype, otype, op)

-// clang-format off
-#define instantiate_binary_float(name, op1, op2) \
-  instantiate_binary_all(name, float16, half, half, op1, op2) \
-  instantiate_binary_all(name, float32, float, float, op1, op2) \
-  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
+#define instantiate_binary_float(name, op) \
+  instantiate_binary_all(name, float16, half, half, op) \
+  instantiate_binary_all(name, float32, float, float, op) \
+  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)

-// clang-format off
-#define instantiate_binary_types(name, op1, op2) \
-  instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
-  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
-  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
-  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
-  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
-  instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
-  instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
-  instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
-  instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
-  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
-  instantiate_binary_float(name, op1, op2)
+#define instantiate_binary_types(name, op) \
+  instantiate_binary_all(name, bool_, bool, bool, op) \
+  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
+  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
+  instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
+  instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
+  instantiate_binary_all(name, int8, int8_t, int8_t, op) \
+  instantiate_binary_all(name, int16, int16_t, int16_t, op) \
+  instantiate_binary_all(name, int32, int32_t, int32_t, op) \
+  instantiate_binary_all(name, int64, int64_t, int64_t, op) \
+  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
+  instantiate_binary_float(name, op)

-instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
+instantiate_binary_types(divmod, DivMod) // clang-format on
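Note: the divmod instantiation now passes a single DivMod op where it previously passed FloorDivide and Remainder separately; presumably the reworked two-output kernels in binary_two.h (shown earlier in this commit, not in this hunk) call that one op and write both of its results. The DivMod functor itself is defined elsewhere in the change; purely as an illustration of the shape such an op could take, under those assumptions:

#include <array>
#include <cmath>

// Hypothetical sketch of a two-output binary functor, not the library's
// definition: one call yields both the floor quotient and the remainder.
struct DivModSketch {
  std::array<float, 2> operator()(float x, float y) const {
    float q = std::floor(x / y);
    return {q, x - q * y};
  }
};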
@@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"

typedef half float16_t;
mlx/backend/metal/kernels/copy.h (new file, 144 lines)
@@ -0,0 +1,144 @@
// Copyright © 2024 Apple Inc.

template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[0]);
}

template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  dst[index] = static_cast<U>(src[index]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  int64_t dst_idx =
      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1(index, src_stride);
  auto dst_idx = elem_to_loc_1(index, dst_stride);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2(index, src_strides);
  auto dst_idx = elem_to_loc_2(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3(index, src_strides);
  auto dst_idx = elem_to_loc_3(index, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}

template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
  dst[dst_idx] = static_cast<U>(src[src_idx]);
}
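Note: in this header the "g" variants read through source strides and write a contiguous destination rebuilt from the thread grid position, while the "gg" variants apply independent strides to both source and destination. A CPU-side C++ sketch of the 2-D "g" case, written only to illustrate the addressing and assuming the x-fastest stride convention suggested by the dst_idx computation above:

#include <cstdint>
#include <vector>

// Illustrative reference for copy_g_nd2-style addressing: (x, y) plays the
// role of the thread position, src_strides are element strides, and the
// destination index is the row-major flattening the kernel computes.
template <typename T, typename U>
void copy_g_nd2_ref(
    const std::vector<T>& src,
    std::vector<U>& dst,
    int64_t nx,
    int64_t ny,
    const int64_t src_strides[2]) {
  for (int64_t y = 0; y < ny; ++y) {
    for (int64_t x = 0; x < nx; ++x) {
      int64_t src_idx = x * src_strides[1] + y * src_strides[0];
      int64_t dst_idx = x + nx * y; // index.x + grid_dim.x * index.y
      dst[dst_idx] = static_cast<U>(src[src_idx]);
    }
  }
}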
@@ -1,150 +1,9 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.

-#include "mlx/backend/metal/kernels/bf16.h"
+// clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/copy.h"
-
-template <typename T, typename U>
-[[kernel]] void copy_s(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]) {
-  dst[index] = static_cast<U>(src[0]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_v(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    uint index [[thread_position_in_grid]]) {
-  dst[index] = static_cast<U>(src[index]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_g_nd1(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    uint index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_1(index, src_stride);
-  dst[index] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_g_nd2(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_2(index, src_strides);
-  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_g_nd3(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_3(index, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U, int DIM>
-[[kernel]] void copy_g_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_g(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]],
-    uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
-  int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_gg_nd1(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int64_t& src_stride [[buffer(3)]],
-    constant const int64_t& dst_stride [[buffer(4)]],
-    uint index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_1(index, src_stride);
-  auto dst_idx = elem_to_loc_1(index, dst_stride);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_gg_nd2(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint2 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_2(index, src_strides);
-  auto dst_idx = elem_to_loc_2(index, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_gg_nd3(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_3(index, src_strides);
-  auto dst_idx = elem_to_loc_3(index, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U, int DIM>
-[[kernel]] void copy_gg_nd(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}
-
-template <typename T, typename U>
-[[kernel]] void copy_gg(
-    device const T* src [[buffer(0)]],
-    device U* dst [[buffer(1)]],
-    constant const int* src_shape [[buffer(2)]],
-    constant const int64_t* src_strides [[buffer(3)]],
-    constant const int64_t* dst_strides [[buffer(4)]],
-    constant const int& ndim [[buffer(5)]],
-    uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
-  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
-}

 #define instantiate_copy(name, itype, otype, ctype) \
   template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
@@ -153,7 +12,7 @@ template <typename T, typename U>
       uint index [[thread_position_in_grid]]);

 #define instantiate_copy_g_dim(name, itype, otype, dims) \
-  template [[host_name(name "_" #dims)]] [[kernel]] void \
+  template [[host_name("g" #dims "_" name)]] [[kernel]] void \
   copy_g_nd<itype, otype, dims>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
@@ -161,7 +20,7 @@ template <typename T, typename U>
       constant const int64_t* src_strides [[buffer(3)]], \
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
-  template [[host_name("g" name "_" #dims)]] [[kernel]] void \
+  template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
   copy_gg_nd<itype, otype, dims>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
@@ -171,38 +30,38 @@ template <typename T, typename U>
       uint3 index [[thread_position_in_grid]]);

 #define instantiate_copy_g_nd(name, itype, otype) \
-  template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
+  template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int64_t& src_stride [[buffer(3)]], \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
+  template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int64_t* src_strides [[buffer(3)]], \
       uint2 index [[thread_position_in_grid]], \
       uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
+  template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
      device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int64_t* src_strides [[buffer(3)]], \
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
-  template [[host_name("g" name "_1")]] [[kernel]] void \
+  template [[host_name("gg1_" name )]] [[kernel]] void \
   copy_gg_nd1<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int64_t& src_stride [[buffer(3)]], \
       constant const int64_t& dst_stride [[buffer(4)]], \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name("g" name "_2")]] [[kernel]] void \
+  template [[host_name("gg2_" name)]] [[kernel]] void \
   copy_gg_nd2<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int64_t* src_strides [[buffer(3)]], \
       constant const int64_t* dst_strides [[buffer(4)]], \
       uint2 index [[thread_position_in_grid]]); \
-  template [[host_name("g" name "_3")]] [[kernel]] void \
+  template [[host_name("gg3_" name)]] [[kernel]] void \
   copy_gg_nd3<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
@@ -213,7 +72,7 @@ template <typename T, typename U>
   instantiate_copy_g_dim(name, itype, otype, 5)

 #define instantiate_copy_g(name, itype, otype) \
-  template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
+  template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int* src_shape [[buffer(2)]], \
@@ -221,7 +80,7 @@ template <typename T, typename U>
       constant const int& ndim [[buffer(5)]], \
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
-  template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
+  template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
       device const itype* src [[buffer(0)]], \
       device otype* dst [[buffer(1)]], \
       constant const int* src_shape [[buffer(2)]], \
@@ -230,14 +89,12 @@ template <typename T, typename U>
       constant const int& ndim [[buffer(5)]], \
       uint3 index [[thread_position_in_grid]]);

-// clang-format off
 #define instantiate_copy_all(tname, itype, otype) \
-  instantiate_copy("scopy" #tname, itype, otype, s) \
-  instantiate_copy("vcopy" #tname, itype, otype, v) \
-  instantiate_copy_g("gcopy" #tname, itype, otype) \
-  instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
+  instantiate_copy("s_copy" #tname, itype, otype, s) \
+  instantiate_copy("v_copy" #tname, itype, otype, v) \
+  instantiate_copy_g("copy" #tname, itype, otype) \
+  instantiate_copy_g_nd("copy" #tname, itype, otype)

-// clang-format off
 #define instantiate_copy_itype(itname, itype) \
   instantiate_copy_all(itname ##bool_, itype, bool) \
   instantiate_copy_all(itname ##uint8, itype, uint8_t) \
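Note: the visible change in this hunk is the host_name scheme. Every copy kernel now gets a fixed prefix (s_, v_, g_, gg_, g1_/g2_/g3_, gg1_/gg2_/gg3_) in front of the copy name instead of a suffix, which gives the kernel names a uniform shape. Purely as an illustration, expanding the renamed macros for a made-up type tag "ff" (not taken from this diff) would register names along these lines:

// Illustration only; "ff" is a hypothetical type tag.
const char* example_copy_kernel_names[] = {
    "s_copyff",  // instantiate_copy("s_copy" #tname, ..., s)
    "v_copyff",  // instantiate_copy("v_copy" #tname, ..., v)
    "g_copyff",  // instantiate_copy_g("copy" #tname, ...): copy_g
    "gg_copyff", // instantiate_copy_g("copy" #tname, ...): copy_gg
    "g1_copyff", // instantiate_copy_g_nd("copy" #tname, ...): copy_g_nd1
    "gg2_copyff" // instantiate_copy_g_nd("copy" #tname, ...): copy_gg_nd2
};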
@@ -8,8 +8,6 @@
 #define MTL_CONST
 #endif

-static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
-static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
 static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
 static MTL_CONST constexpr int REDUCE_N_READS = 16;
 static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
@@ -1,7 +1,6 @@
 // Copyright © 2023 Apple Inc.

 #pragma once

 #include <metal_math>

 /*
mlx/backend/metal/kernels/gather.h (new file, 45 lines)
@@ -0,0 +1,45 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/indexing.h"

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}
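Note: the addressing in gather_impl can be read as: each index array picks a coordinate along one gathered axis, the chosen coordinates are folded through the source strides into a base offset, and then a slice of slice_sizes elements is copied out through the same strides. The following is a simplified, single-axis, CPU-side C++ sketch of that rule, written only for illustration and assuming the trailing slice happens to be contiguous (the real kernel does not require that).

#include <cstddef>
#include <vector>

// Simplified reference: gather along one axis of a strided source. `indices`
// picks a coordinate on that axis for every output row; the inner loop plays
// the role of the ind_offset dimension of the kernel grid.
template <typename T>
void gather_axis_ref(
    const std::vector<T>& src,
    const std::vector<long>& indices,
    std::vector<T>& out,
    size_t ax_stride,  // stride of the gathered axis, in elements
    size_t ax_size,    // size of the gathered axis (for negative wrapping)
    size_t slice_size) {
  for (size_t i = 0; i < indices.size(); ++i) {
    long idx = indices[i];
    if (idx < 0) {
      idx += static_cast<long>(ax_size); // offset_neg_idx behavior
    }
    size_t src_base = static_cast<size_t>(idx) * ax_stride;
    for (size_t j = 0; j < slice_size; ++j) {
      out[i * slice_size + j] = src[src_base + j];
    }
  }
}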
@@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define make_gather_impl(IDX_ARG, IDX_ARR) \
  template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
  [[kernel]] void gather( \
      const device T* src [[buffer(0)]], \
      device T* out [[buffer(1)]], \
      const constant int* src_shape [[buffer(2)]], \
      const constant size_t* src_strides [[buffer(3)]], \
      const constant size_t& src_ndim [[buffer(4)]], \
      const constant int* slice_sizes [[buffer(5)]], \
      const constant int* axes [[buffer(6)]], \
      const constant int* idx_shapes [[buffer(7)]], \
      const constant size_t* idx_strides [[buffer(8)]], \
      const constant int& idx_ndim [[buffer(9)]], \
      IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]) { \
    Indices<IdxT, NIDX> idxs{ \
        {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
                                                           \
    return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
        src, \
        out, \
        src_shape, \
        src_strides, \
        src_ndim, \
        slice_sizes, \
        axes, \
        idxs, \
        index, \
        grid_dim); \
  }

#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)

make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)

/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
  template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
  gather<src_t, idx_t, nidx, nd>( \
      const device src_t* src [[buffer(0)]], \
      device src_t* out [[buffer(1)]], \
      const constant int* src_shape [[buffer(2)]], \
      const constant size_t* src_strides [[buffer(3)]], \
      const constant size_t& src_ndim [[buffer(4)]], \
      const constant int* slice_sizes [[buffer(5)]], \
      const constant int* axes [[buffer(6)]], \
      const constant int* idx_shapes [[buffer(7)]], \
      const constant size_t* idx_strides [[buffer(8)]], \
      const constant int& idx_ndim [[buffer(9)]], \
      IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
  instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on

// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
  instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
  instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
  instantiate_gather5(name, src_t, idx_t, nidx, 2, )


// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on

// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10) // clang-format on

// clang-format off
#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on
@@ -1,13 +1,9 @@
 // Copyright © 2023-2024 Apple Inc.

+#pragma once

 #include <metal_stdlib>

-using namespace metal;
-
-/////////////////////////////////////////////////////////////////////
-// Indexing utils
-/////////////////////////////////////////////////////////////////////
-
 template <typename IdxT, int NIDX>
 struct Indices {
   const array<const device IdxT*, NIDX> buffers;
@@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
   return (idx < 0) ? idx + size : idx;
 }
 }
-
-#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
-
-#define IDX_ARG_0(idx_t)
-#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
-#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
-#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
-#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
-#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
-#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
-#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
-#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
-#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
-#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
-
-#define IDX_ARR_N(n) idx##n,
-
-#define IDX_ARR_0()
-#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
-#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
-#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
-#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
-#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
-#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
-#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
-#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
-#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
-#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
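Note: the removed IDX_ARG_*/IDX_ARR_* helpers only spliced a fixed number of index buffers into a kernel signature, presumably no longer needed here now that the gather and scatter bodies live in the headers above and the wrappers can be generated (including under the JIT path). For reference, mechanically expanding the two-index case of the removed macros gives, in Metal terms:

// IDX_ARG_2(idx_t) -> const device idx_t* idx21 [[buffer(21)]],
//                     const device idx_t* idx22 [[buffer(22)]],
// IDX_ARR_2()      -> idx21, idx22,
// i.e. up to ten index arrays were bound at the fixed buffer slots 21 through 30.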
mlx/backend/metal/kernels/reduction.h (new file, 6 lines)
@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
@@ -5,9 +5,11 @@
 #include <metal_atomic>
 #include <metal_simdgroup>

+#ifndef MLX_METAL_JIT
 #include "mlx/backend/metal/kernels/atomic.h"
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
+#endif

 union bool4_or_uint {
   bool4 b;
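Note: the new MLX_METAL_JIT guard lets this header be consumed two ways. In the ahead-of-time build it keeps pulling in atomic.h, bf16.h, and utils.h itself; under the JIT build flag the includes are compiled out, presumably because the runtime kernel builder already inlines those declarations into the generated source. A rough host-side sketch of that pattern, with entirely hypothetical names and no claim about the library's actual API:

#include <string>

// Hypothetical illustration of assembling JIT kernel source: a preamble that
// already contains the utility and type declarations is prepended, and the
// MLX_METAL_JIT define keeps headers like this one from re-including them.
std::string build_jit_source(
    const std::string& preamble, // e.g. inlined contents of utils.h / bf16.h
    const std::string& kernel_body) {
  std::string source = "#define MLX_METAL_JIT\n";
  source += preamble;
  source += kernel_body;
  return source;
}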
mlx/backend/metal/kernels/scatter.h (new file, 66 lines)
@@ -0,0 +1,66 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/indexing.h"

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;

  uint out_idx = 0;
  for (int i = 0; i < NIDX; i++) {
    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
    out_idx += idx_val * out_strides[i];
  }

  op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;
  auto ind_idx = gid.y;
  auto ind_offset = gid.x;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  if (upd_size > 1) {
    auto out_offset = elem_to_loc(
        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
    out_idx += out_offset;
  }

  auto upd_idx =
      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx);
}
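Note: scatter_impl resolves an output location from the index arrays and then merges the corresponding update into it with the op's atomic_update, which is what keeps concurrent writes to the same location safe. A serial CPU-side sketch of the same update rule, for illustration only (the `op` here is just any callable that merges an update into an accumulator, standing in for the Sum/Prod/Max/Min/None ops):

#include <cstddef>
#include <functional>
#include <vector>

// Simplified, single-axis reference for the scatter update rule: each update
// element is merged into out[indices[i]] with `op`. The Metal kernel performs
// the same merge with an atomic update per element.
template <typename T>
void scatter_axis_ref(
    std::vector<T>& out,
    const std::vector<long>& indices,
    const std::vector<T>& updates,
    const std::function<T(T, T)>& op) {
  for (size_t i = 0; i < indices.size(); ++i) {
    long idx = indices[i];
    if (idx < 0) {
      idx += static_cast<long>(out.size()); // offset_neg_idx behavior
    }
    size_t j = static_cast<size_t>(idx);
    out[j] = op(out[j], updates[i]);
  }
}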
@@ -1,236 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;

  uint out_idx = 0;
  for (int i = 0; i < NIDX; i++) {
    auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
    out_idx += idx_val * out_strides[i];
  }

  op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}

#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
  template <typename T, typename IdxT, typename Op, int NIDX> \
  [[kernel]] void scatter_1d_index( \
      const device T* updates [[buffer(1)]], \
      device mlx_atomic<T>* out [[buffer(2)]], \
      const constant int* out_shape [[buffer(3)]], \
      const constant size_t* out_strides [[buffer(4)]], \
      const constant size_t& upd_size [[buffer(5)]], \
      IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
    const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
                                                                     \
    return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
        updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
  }

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;
  auto ind_idx = gid.y;
  auto ind_offset = gid.x;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  if (upd_size > 1) {
    auto out_offset = elem_to_loc(
        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
    out_idx += out_offset;
  }

  auto upd_idx =
      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
  op.atomic_update(out, updates[upd_idx], out_idx);
}

#define make_scatter_impl(IDX_ARG, IDX_ARR) \
  template <typename T, typename IdxT, typename Op, int NIDX> \
  [[kernel]] void scatter( \
      const device T* updates [[buffer(1)]], \
      device mlx_atomic<T>* out [[buffer(2)]], \
      const constant int* upd_shape [[buffer(3)]], \
      const constant size_t* upd_strides [[buffer(4)]], \
      const constant size_t& upd_ndim [[buffer(5)]], \
      const constant size_t& upd_size [[buffer(6)]], \
      const constant int* out_shape [[buffer(7)]], \
      const constant size_t* out_strides [[buffer(8)]], \
      const constant size_t& out_ndim [[buffer(9)]], \
      const constant int* axes [[buffer(10)]], \
      const constant int* idx_shapes [[buffer(11)]], \
      const constant size_t* idx_strides [[buffer(12)]], \
      const constant int& idx_ndim [[buffer(13)]], \
      IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
    Indices<IdxT, NIDX> idxs{ \
        {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
                                                           \
    return scatter_impl<T, IdxT, Op, NIDX>( \
        updates, \
        out, \
        upd_shape, \
        upd_strides, \
        upd_ndim, \
        upd_size, \
        out_shape, \
        out_strides, \
        out_ndim, \
        axes, \
        idxs, \
        gid); \
  }

#define make_scatter(n) \
  make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
  make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)

make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
make_scatter(9) make_scatter(10)

/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
  template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
  scatter<src_t, idx_t, op_t, nidx>( \
      const device src_t* updates [[buffer(1)]], \
      device mlx_atomic<src_t>* out [[buffer(2)]], \
      const constant int* upd_shape [[buffer(3)]], \
      const constant size_t* upd_strides [[buffer(4)]], \
      const constant size_t& upd_ndim [[buffer(5)]], \
      const constant size_t& upd_size [[buffer(6)]], \
      const constant int* out_shape [[buffer(7)]], \
      const constant size_t* out_strides [[buffer(8)]], \
      const constant size_t& out_ndim [[buffer(9)]], \
      const constant int* axes [[buffer(10)]], \
      const constant int* idx_shapes [[buffer(11)]], \
      const constant size_t* idx_strides [[buffer(12)]], \
      const constant int& idx_ndim [[buffer(13)]], \
      IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
  template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
  scatter_1d_index<src_t, idx_t, op_t, nidx>( \
      const device src_t* updates [[buffer(1)]], \
      device mlx_atomic<src_t>* out [[buffer(2)]], \
      const constant int* out_shape [[buffer(3)]], \
      const constant size_t* out_strides [[buffer(4)]], \
      const constant size_t& upd_size [[buffer(5)]], \
      IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);

// clang-format off
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
  instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
  instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on

// clang-format off
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on

// clang-format off
#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
  instantiate_scatter4(name, type, ind_type, op_type, 3) \
  instantiate_scatter4(name, type, ind_type, op_type, 4) \
  instantiate_scatter4(name, type, ind_type, op_type, 5) \
  instantiate_scatter4(name, type, ind_type, op_type, 6) \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on

// clang-format off
#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on

// clang-format off
#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
  instantiate_scatter2(#name "uint16", type, uint16_t) \
  instantiate_scatter2(#name "uint32", type, uint32_t) \
  instantiate_scatter2(#name "uint64", type, uint64_t) \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t) // clang-format on

// clang-format off
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
|
|
||||||
instantiate_scatter_nd0(uint16, uint16_t)
|
|
||||||
instantiate_scatter_nd0(uint32, uint32_t)
|
|
||||||
instantiate_scatter_nd0(int8, int8_t)
|
|
||||||
instantiate_scatter_nd0(int16, int16_t)
|
|
||||||
instantiate_scatter_nd0(int32, int32_t)
|
|
||||||
instantiate_scatter_nd0(float16, half)
|
|
||||||
instantiate_scatter_nd0(float32, float)
|
|
||||||
instantiate_scatter_nd0(bfloat16, bfloat16_t)
|
|
||||||
|
|
||||||
instantiate_scatter(bool_, bool)
|
|
||||||
instantiate_scatter(uint8, uint8_t)
|
|
||||||
instantiate_scatter(uint16, uint16_t)
|
|
||||||
instantiate_scatter(uint32, uint32_t)
|
|
||||||
instantiate_scatter(int8, int8_t)
|
|
||||||
instantiate_scatter(int16, int16_t)
|
|
||||||
instantiate_scatter(int32, int32_t)
|
|
||||||
instantiate_scatter(float16, half)
|
|
||||||
instantiate_scatter(float32, float)
|
|
||||||
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on
|
|
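Reading aid (not part of the diff): one concrete trace of the removed instantiation macros above, as I read them.

// Illustrative expansion only. instantiate_scatter_nd0(float32, float) goes
// through instantiate_scatter4("float32none", float, bool, None, 0), which
// emits two template instantiations:
//   host_name "scatterfloat32none_0"          -> scatter<float, bool, None, 0>
//   host_name "scatter_1d_indexfloat32none_0" -> scatter_1d_index<float, bool, None, 0>
// and likewise for the _sum/_prod/_max/_min reductions.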
@@ -1,10 +1,102 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.

-#pragma once
-
-struct Select {
-  template <typename T>
-  T operator()(bool condition, T x, T y) {
-    return condition ? x : y;
-  }
-};
+template <typename T, typename Op>
+[[kernel]] void ternary_v(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    uint index [[thread_position_in_grid]]) {
+  d[index] = Op()(a[index], b[index], c[index]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_g_nd1(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t& a_strides,
+    constant const size_t& b_strides,
+    constant const size_t& c_strides,
+    uint index [[thread_position_in_grid]]) {
+  auto a_idx = elem_to_loc_1(index, a_strides);
+  auto b_idx = elem_to_loc_1(index, b_strides);
+  auto c_idx = elem_to_loc_1(index, c_strides);
+  d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_g_nd2(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[2],
+    constant const size_t b_strides[2],
+    constant const size_t c_strides[2],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_2(index, a_strides);
+  auto b_idx = elem_to_loc_2(index, b_strides);
+  auto c_idx = elem_to_loc_2(index, c_strides);
+  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_g_nd3(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const size_t a_strides[3],
+    constant const size_t b_strides[3],
+    constant const size_t c_strides[3],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto a_idx = elem_to_loc_3(index, a_strides);
+  auto b_idx = elem_to_loc_3(index, b_strides);
+  auto c_idx = elem_to_loc_3(index, c_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
+}
+
+template <typename T, typename Op, int DIM>
+[[kernel]] void ternary_g_nd(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int shape[DIM],
+    constant const size_t a_strides[DIM],
+    constant const size_t b_strides[DIM],
+    constant const size_t c_strides[DIM],
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx =
+      elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
+  size_t out_idx =
+      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void ternary_g(
+    device const bool* a,
+    device const T* b,
+    device const T* c,
+    device T* d,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    constant const int& ndim,
+    uint3 index [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  auto idx =
+      elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
+  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
+  d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
+}
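A quick orientation on how these templates are picked at run time: the host code chooses between the contiguous "v" kernel and the strided "g" variants by contiguity and rank, then appends the op and dtype to form the pipeline name (the removed ternary_op() shown near the end of this diff does exactly this). The sketch below restates that logic in plain C++; it is not an MLX API, and the dtype suffix produced by type_to_name() is an assumption.

#include <string>

// Sketch only: assemble a ternary kernel name the way the (old) host
// dispatcher does. max_specialized_dims mirrors MAX_BINARY_SPECIALIZED_DIMS.
std::string ternary_kernel_name(
    bool general,
    int ndim,
    int max_specialized_dims,
    const std::string& op,            // e.g. "select"
    const std::string& dtype_suffix)  // assumed output of type_to_name()
{
  std::string name = general ? "g" : "v";
  name += op + dtype_suffix;
  if (general && ndim <= max_specialized_dims) {
    name += "_" + std::to_string(ndim);  // picks ternary_g_nd1/2/3/nd<DIM>
  }
  return name;
}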
@@ -1,115 +1,16 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2024 Apple Inc.

#include <metal_integer>
#include <metal_math>

+// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/ternary_ops.h"
+#include "mlx/backend/metal/kernels/ternary.h"

[The removed lines that follow in this hunk are the kernel templates ternary_op_v, ternary_op_g_nd1, ternary_op_g_nd2, ternary_op_g_nd3, ternary_op_g_nd and ternary_op_g. Apart from dropping the "_op" from each name, they are identical to the ternary_v/ternary_g_* templates now living in ternary.h above, so they are not repeated here.]

#define instantiate_ternary_v(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
+  template [[host_name("v_" name)]] [[kernel]] void ternary_v<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -117,7 +18,7 @@ template <typename T, typename Op>
      uint index [[thread_position_in_grid]]);

#define instantiate_ternary_g(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
+  template [[host_name("g_" name)]] [[kernel]] void ternary_g<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -131,8 +32,8 @@ template <typename T, typename Op>
      uint3 grid_dim [[threads_per_grid]]);

#define instantiate_ternary_g_dim(name, type, op, dims) \
-  template [[host_name(name "_" #dims)]] [[kernel]] void \
-  ternary_op_g_nd<type, op, dims>( \
+  template [[host_name("g" #dims "_" name)]] [[kernel]] void \
+  ternary_g_nd<type, op, dims>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -145,8 +46,8 @@ template <typename T, typename Op>
      uint3 grid_dim [[threads_per_grid]]);

#define instantiate_ternary_g_nd(name, type, op) \
-  template [[host_name(name "_1")]] [[kernel]] void \
-  ternary_op_g_nd1<type, op>( \
+  template [[host_name("g1_" name)]] [[kernel]] void \
+  ternary_g_nd1<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -155,8 +56,8 @@ template <typename T, typename Op>
      constant const size_t& b_strides, \
      constant const size_t& c_strides, \
      uint index [[thread_position_in_grid]]); \
-  template [[host_name(name "_2")]] [[kernel]] void \
-  ternary_op_g_nd2<type, op>( \
+  template [[host_name("g2_" name)]] [[kernel]] void \
+  ternary_g_nd2<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -166,8 +67,8 @@ template <typename T, typename Op>
      constant const size_t c_strides[2], \
      uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]); \
-  template [[host_name(name "_3")]] [[kernel]] void \
-  ternary_op_g_nd3<type, op>( \
+  template [[host_name("g3_" name)]] [[kernel]] void \
+  ternary_g_nd3<type, op>( \
      device const bool* a, \
      device const type* b, \
      device const type* c, \
@@ -180,13 +81,11 @@ template <typename T, typename Op>
  instantiate_ternary_g_dim(name, type, op, 4) \
  instantiate_ternary_g_dim(name, type, op, 5)

-// clang-format off
#define instantiate_ternary_all(name, tname, type, op) \
-  instantiate_ternary_v("v" #name #tname, type, op) \
-  instantiate_ternary_g("g" #name #tname, type, op) \
-  instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
+  instantiate_ternary_v(#name #tname, type, op) \
+  instantiate_ternary_g(#name #tname, type, op) \
+  instantiate_ternary_g_nd(#name #tname, type, op)

-// clang-format off
#define instantiate_ternary_types(name, op) \
  instantiate_ternary_all(name, bool_, bool, op) \
  instantiate_ternary_all(name, uint8, uint8_t, op) \
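The net effect of this hunk is a rename of the host-visible pipelines. For one concrete case, as I trace the macros (the op/type pair is illustrative):

// old scheme:
//   instantiate_ternary_all(select, float16, half, Select)
//     -> host_name "vselectfloat16", "gselectfloat16", "gselectfloat16_1", "gselectfloat16_2", ...
// new scheme:
//   instantiate_ternary_all(select, float16, half, Select)
//     -> host_name "v_selectfloat16", "g_selectfloat16", "g1_selectfloat16", "g2_selectfloat16", ...
// i.e. the contiguity/rank tag moves to a fixed-position prefix, which is
// easier to assemble when kernel names are generated for the JIT path.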
mlx/backend/metal/kernels/ternary_ops.h (new file)
@@ -0,0 +1,10 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};
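For reference, the same functor restated as ordinary, compilable C++, to make the predicate semantics explicit (the bool input a drives the choice between b and c in the ternary kernels above):

#include <cassert>

// Plain C++ twin of the Metal Select functor for host-side reasoning/tests.
struct SelectRef {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};

int main() {
  SelectRef sel;
  assert(sel(true, 1.0f, 2.0f) == 1.0f);  // condition true -> first branch
  assert(sel(false, 1, 2) == 2);          // condition false -> second branch
  return 0;
}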
@@ -1,394 +1,21 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.

-#pragma once
-
-#include <metal_integer>
-#include <metal_math>
-
-#include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/erf.h"
-#include "mlx/backend/metal/kernels/expm1f.h"
-#include "mlx/backend/metal/kernels/utils.h"
-
-namespace {
-constant float inf = metal::numeric_limits<float>::infinity();
-}
+template <typename T, typename Op>
+[[kernel]] void unary_v(
+    device const T* in,
+    device T* out,
+    uint index [[thread_position_in_grid]]) {
+  out[index] = Op()(in[index]);
+}
+
+template <typename T, typename Op>
+[[kernel]] void unary_g(
+    device const T* in,
+    device T* out,
+    device const int* in_shape,
+    device const size_t* in_strides,
+    device const int& ndim,
+    uint index [[thread_position_in_grid]]) {
+  auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
+  out[index] = Op()(in[idx]);
+}

[Also removed from unary.h in this hunk: the operator structs Abs through Tanh. They move unchanged into the new unary_ops.h shown further below, so they are not repeated here.]
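unary_g leans on elem_to_loc from utils.h to turn a flat thread index into an offset into a strided input. The helper below is a sketch of the usual row-major formulation of that mapping, written as standalone C++ for reference; it is not the exact MLX implementation.

#include <cstddef>

// Map a contiguous (row-major) element index to an offset in a strided array.
// Dimensions are consumed from innermost (last) to outermost, matching how the
// kernels above iterate.
size_t elem_to_loc_ref(
    size_t index, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0 && index > 0; --i) {
    loc += (index % static_cast<size_t>(shape[i])) * strides[i];
    index /= static_cast<size_t>(shape[i]);
  }
  return loc;
}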
@@ -1,35 +1,18 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.

+// clang-format off
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"

[The removed lines here are the kernel templates unary_op_v and unary_op_g; apart from dropping the "_op" from the names they are identical to unary_v and unary_g in the new unary.h above.]

#define instantiate_unary_v(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
+  template [[host_name(name)]] [[kernel]] void unary_v<type, op>( \
      device const type* in, \
      device type* out, \
      uint index [[thread_position_in_grid]]);

#define instantiate_unary_g(name, type, op) \
-  template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
+  template [[host_name(name)]] [[kernel]] void unary_g<type, op>( \
      device const type* in, \
      device type* out, \
      device const int* in_shape, \
@@ -37,18 +20,15 @@ template <typename T, typename Op>
      device const int& ndim, \
      uint index [[thread_position_in_grid]]);

-// clang-format off
#define instantiate_unary_all(name, tname, type, op) \
  instantiate_unary_v("v" #name #tname, type, op) \
-  instantiate_unary_g("g" #name #tname, type, op) // clang-format on
+  instantiate_unary_g("g" #name #tname, type, op)

-// clang-format off
#define instantiate_unary_float(name, op) \
  instantiate_unary_all(name, float16, half, op) \
  instantiate_unary_all(name, float32, float, op) \
-  instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
+  instantiate_unary_all(name, bfloat16, bfloat16_t, op)

-// clang-format off
#define instantiate_unary_types(name, op) \
  instantiate_unary_all(name, bool_, bool, op) \
  instantiate_unary_all(name, uint8, uint8_t, op) \
@@ -59,9 +39,8 @@ template <typename T, typename Op>
  instantiate_unary_all(name, int16, int16_t, op) \
  instantiate_unary_all(name, int32, int32_t, op) \
  instantiate_unary_all(name, int64, int64_t, op) \
-  instantiate_unary_float(name, op) // clang-format on
+  instantiate_unary_float(name, op)

-// clang-format off
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
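For one concrete case, as I read the macros above, instantiate_unary_all(abs, float16, half, Abs) expands to two host-visible kernels:

//   host_name "vabsfloat16" -> unary_v<half, Abs>
//   host_name "gabsfloat16" -> unary_g<half, Abs>
// Unlike the ternary macros, the "v"/"g" contiguity tag here is still baked
// into the name string passed by instantiate_unary_all rather than added by
// the instantiation macro itself.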
mlx/backend/metal/kernels/unary_ops.h (new file)
@@ -0,0 +1,392 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_integer>
#include <metal_math>

#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"

namespace {
constant float inf = metal::numeric_limits<float>::infinity();
}

struct Abs {
  template <typename T>
  T operator()(T x) {
    return metal::abs(x);
  };
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  };
  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  };
  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  };
  template <>
  bool operator()(bool x) {
    return x;
  };
  template <>
  complex64_t operator()(complex64_t x) {
    return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
  };
};

struct ArcCos {
  template <typename T>
  T operator()(T x) {
    return metal::precise::acos(x);
  };
};

struct ArcCosh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::acosh(x);
  };
};

struct ArcSin {
  template <typename T>
  T operator()(T x) {
    return metal::precise::asin(x);
  };
};

struct ArcSinh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::asinh(x);
  };
};

struct ArcTan {
  template <typename T>
  T operator()(T x) {
    return metal::precise::atan(x);
  };
};

struct ArcTanh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::atanh(x);
  };
};

struct Ceil {
  template <typename T>
  T operator()(T x) {
    return metal::ceil(x);
  };
  template <>
  int8_t operator()(int8_t x) {
    return x;
  };
  template <>
  int16_t operator()(int16_t x) {
    return x;
  };
  template <>
  int32_t operator()(int32_t x) {
    return x;
  };
  template <>
  int64_t operator()(int64_t x) {
    return x;
  };
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  };
  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  };
  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  };
  template <>
  bool operator()(bool x) {
    return x;
  };
};

struct Cos {
  template <typename T>
  T operator()(T x) {
    return metal::precise::cos(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
        -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
  };
};

struct Cosh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::cosh(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
        metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
  };
};

struct Conjugate {
  complex64_t operator()(complex64_t x) {
    return complex64_t{x.real, -x.imag};
  }
};

struct Erf {
  template <typename T>
  T operator()(T x) {
    return static_cast<T>(erf(static_cast<float>(x)));
  };
};

struct ErfInv {
  template <typename T>
  T operator()(T x) {
    return static_cast<T>(erfinv(static_cast<float>(x)));
  };
};

struct Exp {
  template <typename T>
  T operator()(T x) {
    return metal::precise::exp(x);
  };
  template <>
  complex64_t operator()(complex64_t x) {
    auto m = metal::precise::exp(x.real);
    return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
  }
};

struct Expm1 {
  template <typename T>
  T operator()(T x) {
    return static_cast<T>(expm1f(static_cast<float>(x)));
  };
};

struct Floor {
  template <typename T>
  T operator()(T x) {
    return metal::floor(x);
  };
  template <>
  int8_t operator()(int8_t x) {
    return x;
  };
  template <>
  int16_t operator()(int16_t x) {
    return x;
  };
  template <>
  int32_t operator()(int32_t x) {
    return x;
  };
  template <>
  int64_t operator()(int64_t x) {
    return x;
  };
  template <>
  uint8_t operator()(uint8_t x) {
    return x;
  };
  template <>
  uint16_t operator()(uint16_t x) {
    return x;
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x;
  };
  template <>
  uint64_t operator()(uint64_t x) {
    return x;
  };
  template <>
  bool operator()(bool x) {
    return x;
  };
};

struct Log {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log(x);
  };
};

struct Log2 {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log2(x);
  };
};

struct Log10 {
  template <typename T>
  T operator()(T x) {
    return metal::precise::log10(x);
  };
};

struct Log1p {
  template <typename T>
  T operator()(T x) {
    return log1p(x);
  };
};

struct LogicalNot {
  template <typename T>
  T operator()(T x) {
    return !x;
  };
};

struct Negative {
  template <typename T>
  T operator()(T x) {
    return -x;
  };
};

struct Round {
  template <typename T>
  T operator()(T x) {
    return metal::rint(x);
  };
  template <>
  complex64_t operator()(complex64_t x) {
    return {metal::rint(x.real), metal::rint(x.imag)};
  };
};

struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto y = 1 / (1 + metal::exp(-metal::abs(x)));
    return (x < 0) ? 1 - y : y;
  }
};

struct Sign {
  template <typename T>
  T operator()(T x) {
    return (x > T(0)) - (x < T(0));
  };
  template <>
  uint32_t operator()(uint32_t x) {
    return x != 0;
  };
};

struct Sin {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sin(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
        metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
  };
};

struct Sinh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sinh(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    return {
        metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
        metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
  };
};

struct Square {
  template <typename T>
  T operator()(T x) {
    return x * x;
  };
};

struct Sqrt {
  template <typename T>
  T operator()(T x) {
    return metal::precise::sqrt(x);
  };
};

struct Rsqrt {
  template <typename T>
  T operator()(T x) {
    return metal::precise::rsqrt(x);
  };
};

struct Tan {
  template <typename T>
  T operator()(T x) {
    return metal::precise::tan(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    float tan_a = metal::precise::tan(x.real);
    float tanh_b = metal::precise::tanh(x.imag);
    float t1 = tan_a * tanh_b;
    float denom = 1. + t1 * t1;
    return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
  };
};

struct Tanh {
  template <typename T>
  T operator()(T x) {
    return metal::precise::tanh(x);
  };

  template <>
  complex64_t operator()(complex64_t x) {
    float tanh_a = metal::precise::tanh(x.real);
    float tan_b = metal::precise::tan(x.imag);
    float t1 = tanh_a * tan_b;
    float denom = 1. + t1 * t1;
    return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
  };
};
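The complex specializations of Tan and Tanh above implement the standard identities; writing them out makes the t1/denom temporaries easy to check. This is a reference derivation, not part of the diff:

\[
\tan(a+ib) = \frac{\tan a + i\,\tanh b}{1 - i\,\tan a\,\tanh b}
           = \frac{(\tan a - \tanh b\,t_1) + i\,(\tanh b + \tan a\,t_1)}{1 + t_1^2},
\qquad t_1 = \tan a\,\tanh b,
\]
\[
\tanh(a+ib) = \frac{\tanh a + i\,\tan b}{1 + i\,\tanh a\,\tan b}
            = \frac{(\tanh a + \tan b\,t_1) + i\,(\tan b - \tanh a\,t_1)}{1 + t_1^2},
\qquad t_1 = \tanh a\,\tan b.
\]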
@@ -6,6 +6,8 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/complex.h"

+typedef half float16_t;
+
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
@@ -5,24 +5,25 @@
#
# Copyright © 2023-24 Apple Inc.

-OUTPUT_FILE=$1
+OUTPUT_DIR=$1
CC=$2
-SRCDIR=$3
-CFLAGS=$4
+SRC_DIR=$3
+SRC_NAME=$4
+CFLAGS=$5
+INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h
+OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp

-CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
+mkdir -p $OUTPUT_DIR
+
+CONTENT=$($CC -I $SRC_DIR -DMLX_METAL_JIT -E -P $INPUT_FILE $CFLAGS 2>/dev/null)

cat << EOF > "$OUTPUT_FILE"
-// Copyright © 2023-24 Apple Inc.
-
namespace mlx::core::metal {

-const char* get_kernel_preamble() {
+const char* $SRC_NAME() {
return R"preamble(
$CONTENT
)preamble";

}

} // namespace mlx::core::metal
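Instead of emitting a single fixed get_kernel_preamble() translation unit, the script now takes an output directory and a kernel-family name, preprocesses the corresponding header with -DMLX_METAL_JIT, and writes one C++ file per family (e.g. an invocation along the lines of "make_compiled_preamble.sh <build_dir>/jit clang++ <repo_root> unary <cflags>", where the paths are illustrative). Roughly what it generates for SRC_NAME=unary, with the embedded string abbreviated:

// Sketch of the generated ${OUTPUT_DIR}/unary.cpp (contents abbreviated; the
// real raw string is the preprocessed unary.h source).
namespace mlx::core::metal {

const char* unary() {
  return R"preamble(
/* preprocessed contents of mlx/backend/metal/kernels/unary.h */
)preamble";
}

} // namespace mlx::core::metal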
mlx/backend/metal/nojit_kernels.cpp (new file)
@@ -0,0 +1,46 @@

#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

MTL::ComputePipelineState* get_unary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_binary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_binary_two_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_ternary_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_copy_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& in,
    const array& out) {
  return d.get_kernel(kernel_name);
}

} // namespace mlx::core

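These stubs keep the host-side call sites the same whether or not MLX_METAL_JIT is enabled: without the JIT, the dtype arguments are simply ignored and the request forwards to the prebuilt metallib by pipeline name. A sketch of a call site, with an illustrative kernel name rather than the exact string MLX builds:

// Sketch only (requires the MLX headers above to compile):
//   auto kernel = get_ternary_kernel(d, "v_selectfloat32", out);
//   compute_encoder->setComputePipelineState(kernel);
// The JIT counterpart elsewhere in this commit provides the same signatures,
// so the primitives do not need to know which build flavor is active.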
@@ -4,366 +4,14 @@
#include <numeric>
#include <sstream>

-#include "mlx/backend/common/binary.h"
-#include "mlx/backend/common/ternary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

-namespace {
-
-constexpr int METAL_MAX_INDEX_ARRAYS = 10;
-
-void binary_op(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs,
-    const std::string op) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, outputs[0], bopt, true);
-  set_binary_op_output_data(a, b, outputs[1], bopt, true);
-
-  auto& out = outputs[0];
-  if (out.size() == 0) {
-    return;
-  }
-
-  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_out = strides[2];
-
-  std::ostringstream kname;
-  switch (bopt) {
-    case BinaryOpType::ScalarScalar:
-      kname << "ss";
-      break;
-    case BinaryOpType::ScalarVector:
-      kname << "sv";
-      break;
-    case BinaryOpType::VectorScalar:
-      kname << "vs";
-      break;
-    case BinaryOpType::VectorVector:
-      kname << "vv";
-      break;
-    case BinaryOpType::General:
-      kname << "g";
-      break;
-  }
-  kname << op << type_to_name(a);
-  if (bopt == BinaryOpType::General &&
-      shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
-    kname << "_" << shape.size();
-  }
-
-  auto& s = out.primitive().stream();
-  auto& d = metal::device(s.device);
-  auto kernel = d.get_kernel(kname.str());
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  // - If a is donated it goes to the first output
-  // - If b is donated it goes to the first output if a was not donated
-  //   otherwise it goes to the second output
-  bool donate_a = a.data_shared_ptr() == nullptr;
-  bool donate_b = b.data_shared_ptr() == nullptr;
-  compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
-  compute_encoder.set_input_array(
-      donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
-  compute_encoder.set_output_array(outputs[0], 2);
-  compute_encoder.set_output_array(outputs[1], 3);
-
-  if (bopt == BinaryOpType::General) {
-    auto ndim = shape.size();
-    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-    }
-
-    if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 7);
-    }
-
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
-    }
-    auto group_dims = get_block_dims(dim0, dim1, rest);
-    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  } else {
-    // Launch a 1D grid of threads
-    size_t nthreads = out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size > nthreads) {
-      thread_group_size = nthreads;
-    }
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
-}
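The kernel name this removed dispatcher builds is what pairs with the host_name strings in the instantiation macros earlier in the diff. A standalone illustration of the concatenation order for a vector-vector add on float32 inputs; only the order is taken from the code above, and the "float32" suffix returned by type_to_name() is an assumption:

#include <iostream>
#include <sstream>

int main() {
  std::ostringstream kname;
  kname << "vv";       // BinaryOpType::VectorVector
  kname << "add";      // op
  kname << "float32";  // assumed type_to_name() suffix
  std::cout << kname.str() << "\n";  // prints: vvaddfloat32
  return 0;
}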
|
|
||||||
void binary_op(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
array& out,
|
|
||||||
const std::string op) {
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& a = inputs[0];
|
|
||||||
auto& b = inputs[1];
|
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt, true);
|
|
||||||
if (out.size() == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to collapse contiguous dims
|
|
||||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
|
||||||
auto& strides_a = strides[0];
|
|
||||||
auto& strides_b = strides[1];
|
|
||||||
auto& strides_out = strides[2];
|
|
||||||
|
|
||||||
std::ostringstream kname;
|
|
||||||
switch (bopt) {
|
|
||||||
case BinaryOpType::ScalarScalar:
|
|
||||||
kname << "ss";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::ScalarVector:
|
|
||||||
kname << "sv";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::VectorScalar:
|
|
||||||
kname << "vs";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::VectorVector:
|
|
||||||
kname << "vv";
|
|
||||||
break;
|
|
||||||
case BinaryOpType::General:
|
|
||||||
kname << "g";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
kname << op << type_to_name(a);
|
|
||||||
if (bopt == BinaryOpType::General &&
|
|
||||||
shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
|
||||||
kname << "_" << shape.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& s = out.primitive().stream();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
|
||||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
|
||||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
|
||||||
compute_encoder.set_input_array(donate_a ? out : a, 0);
|
|
||||||
compute_encoder.set_input_array(donate_b ? out : b, 1);
|
|
||||||
compute_encoder.set_output_array(out, 2);
|
|
||||||
|
|
||||||
if (bopt == BinaryOpType::General) {
|
|
||||||
auto ndim = shape.size();
|
|
||||||
if (ndim > 3) {
|
|
||||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
|
|
||||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
|
||||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
|
||||||
} else {
|
|
||||||
// The shape is implicit in the grid for <= 3D
|
|
||||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
|
|
||||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
|
||||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Launch up to 3D grid of threads
|
|
||||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
|
||||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
|
||||||
size_t rest = out.size() / (dim0 * dim1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size != 1024) {
|
|
||||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
|
||||||
}
|
|
||||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
|
||||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
|
||||||
} else {
|
|
||||||
// Launch a 1D grid of threads
|
|
||||||
size_t nthreads =
|
|
||||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
|
||||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
|
||||||
if (thread_group_size > nthreads) {
|
|
||||||
thread_group_size = nthreads;
|
|
||||||
}
|
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
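For reference, the kernel-name construction in the removed binary_op above can be read as a small standalone helper. In the sketch below, type_to_name() returning strings such as "float32" and MAX_BINARY_SPECIALIZED_DIMS being 5 are assumptions made for the example, not values shown in this hunk.

// Illustrative sketch only: mirrors the kname construction in the removed
// binary_op above. type_to_name() output ("float32") and the value of
// MAX_BINARY_SPECIALIZED_DIMS (5) are assumptions for the example.
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, VectorVector, General };

std::string binary_kernel_name(
    BinaryOpType bopt,
    const std::string& op,     // e.g. "add"
    const std::string& tname,  // e.g. "float32"
    size_t ndim) {             // collapsed rank, only used for General
  constexpr size_t MAX_BINARY_SPECIALIZED_DIMS = 5;  // assumed value
  std::ostringstream kname;
  switch (bopt) {
    case BinaryOpType::ScalarScalar: kname << "ss"; break;
    case BinaryOpType::ScalarVector: kname << "sv"; break;
    case BinaryOpType::VectorScalar: kname << "vs"; break;
    case BinaryOpType::VectorVector: kname << "vv"; break;
    case BinaryOpType::General:      kname << "g";  break;
  }
  kname << op << tname;
  if (bopt == BinaryOpType::General && ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
    kname << "_" << ndim;
  }
  return kname.str();
}

int main() {
  std::cout << binary_kernel_name(BinaryOpType::VectorVector, "add", "float32", 1)
            << "\n";  // vvaddfloat32
  std::cout << binary_kernel_name(BinaryOpType::General, "add", "float32", 3)
            << "\n";  // gaddfloat32_3
}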
-void ternary_op(
-    const std::vector<array>& inputs,
-    array& out,
-    const std::string op) {
-  assert(inputs.size() == 3);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  auto& c = inputs[2];
-  TernaryOpType topt = get_ternary_op_type(a, b, c);
-  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
-
-  if (out.size() == 0) {
-    return;
-  }
-
-  // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
-  auto& strides_a = strides[0];
-  auto& strides_b = strides[1];
-  auto& strides_c = strides[2];
-  auto& strides_out = strides[3];
-
-  std::ostringstream kname;
-  if (topt == TernaryOpType::General) {
-    kname << "g";
-    kname << op << type_to_name(b);
-    if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
-      kname << "_" << shape.size();
-    }
-  } else {
-    kname << "v";
-    kname << op << type_to_name(b);
-  }
-
-  auto& s = out.primitive().stream();
-  auto& d = metal::device(s.device);
-  auto kernel = d.get_kernel(kname.str());
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  compute_encoder.set_input_array(a, 0);
-  compute_encoder.set_input_array(b, 1);
-  compute_encoder.set_input_array(c, 2);
-  compute_encoder.set_output_array(out, 3);
-
-  if (topt == TernaryOpType::General) {
-    auto ndim = shape.size();
-    if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
-      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
-
-      if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-        compute_encoder->setBytes(&ndim, sizeof(int), 8);
-      }
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
-      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
-      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
-    }
-
-    // Launch up to 3D grid of threads
-    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
-    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    size_t rest = out.size() / (dim0 * dim1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size != 1024) {
-      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
-    }
-    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
-    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  } else {
-    // Launch a 1D grid of threads
-    size_t nthreads = out.data_size();
-    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-    if (thread_group_size > nthreads) {
-      thread_group_size = nthreads;
-    }
-    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
-  }
-}
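Both general-path launches above hand the threadgroup shape to get_block_dims(dim0, dim1, rest) and then insist on the kernel's 1024-thread maximum. That helper is defined elsewhere and is not part of this diff; the stand-in below only illustrates the idea of spreading a power-of-two threadgroup over up to three axes and should not be read as MLX's actual implementation.

// Not MLX's get_block_dims(); an illustrative stand-in that round-robins
// power-of-two factors across the three launch axes until the 1024-thread
// budget is used up or every axis is covered.
#include <cstddef>
#include <cstdio>

struct Dims3 {
  size_t x, y, z;
};

Dims3 block_dims_sketch(size_t dim0, size_t dim1, size_t rest) {
  size_t extent[3] = {dim0, dim1, rest};
  size_t block[3] = {1, 1, 1};
  size_t total = 1;  // product of the block edges
  bool grew = true;
  while (grew && total < 1024) {
    grew = false;
    for (int i = 0; i < 3; ++i) {
      if (total < 1024 && block[i] < extent[i]) {
        block[i] *= 2;  // keep every edge a power of two
        total *= 2;
        grew = true;
      }
    }
  }
  return {block[0], block[1], block[2]};
}

int main() {
  Dims3 b = block_dims_sketch(1024, 512, 4);
  std::printf("%zu x %zu x %zu\n", b.x, b.y, b.z);  // 16 x 16 x 4
}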
-void unary_op(
-    const std::vector<array>& inputs,
-    array& out,
-    const std::string op) {
-  auto& in = inputs[0];
-  bool contig = in.flags().contiguous;
-  if (contig) {
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
-      out.move_shared_buffer(in);
-    } else {
-      out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-          in.data_size(),
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  }
-  if (in.size() == 0) {
-    return;
-  }
-
-  auto& s = out.primitive().stream();
-  auto& d = metal::device(s.device);
-  std::string tname = type_to_name(in);
-  std::string opt_name = contig ? "v" : "g";
-  auto kernel = d.get_kernel(opt_name + op + tname);
-
-  size_t nthreads = contig ? in.data_size() : in.size();
-  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
-  if (thread_group_size > nthreads) {
-    thread_group_size = nthreads;
-  }
-  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
-  compute_encoder.set_input_array(
-      in.data_shared_ptr() == nullptr ? out : in, 0);
-  compute_encoder.set_output_array(out, 1);
-  if (!contig) {
-    compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
-    compute_encoder->setBytes(
-        in.strides().data(), in.ndim() * sizeof(size_t), 3);
-    int ndim = in.ndim();
-    compute_encoder->setBytes(&ndim, sizeof(int), 4);
-  }
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
-}
-
-} // namespace
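The contiguous ("v") path in all three helpers uses the same one-thread-per-element launch: the grid covers every element and the threadgroup is clamped to the pipeline's maximum width. Stripped of the Metal types, the arithmetic is just the following sketch.

// Plain-integer sketch of the 1-D launch used on the contiguous paths above;
// size_t stands in for MTL::Size / NS::UInteger.
#include <algorithm>
#include <cstddef>
#include <cstdio>

struct Launch1D {
  size_t grid_threads;   // MTL::Size(nthreads, 1, 1)
  size_t group_threads;  // MTL::Size(thread_group_size, 1, 1)
};

Launch1D launch_1d(size_t nthreads, size_t max_threads_per_group) {
  // Clamp the threadgroup so it never exceeds the number of elements.
  return {nthreads, std::min(max_threads_per_group, nthreads)};
}

int main() {
  Launch1D l = launch_1d(10000, 1024);
  std::printf("grid=%zu group=%zu\n", l.grid_threads, l.group_threads);  // grid=10000 group=1024
}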
-void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "abs");
-}
-
-void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "add");
-}
-
 template <typename T>
 void arange_set_scalars(T start, T next, CommandEncoder& enc) {
   enc->setBytes(&start, sizeof(T), 0);
@@ -431,34 +79,6 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

-void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "arccos");
-}
-
-void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "arccosh");
-}
-
-void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "arcsin");
-}
-
-void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "arcsinh");
-}
-
-void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "arctan");
-}
-
-void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "arctan2");
-}
-
-void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "arctanh");
-}
-
 void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -537,26 +157,6 @@ void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }

-void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
-  switch (op_) {
-    case BitwiseBinary::And:
-      binary_op(inputs, out, "bitwise_and");
-      break;
-    case BitwiseBinary::Or:
-      binary_op(inputs, out, "bitwise_or");
-      break;
-    case BitwiseBinary::Xor:
-      binary_op(inputs, out, "bitwise_xor");
-      break;
-    case BitwiseBinary::LeftShift:
-      binary_op(inputs, out, "left_shift");
-      break;
-    case BitwiseBinary::RightShift:
-      binary_op(inputs, out, "right_shift");
-      break;
-  }
-}
-
 void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
@@ -588,29 +188,10 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }

-void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == complex64) {
-    unary_op(inputs, out, "conj");
-  } else {
-    throw std::invalid_argument(
-        "[conjugate] conjugate must be called on complex input.");
-  }
-}
-
 void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }

-void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "cos");
-}
-
-void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "cosh");
-}
-
 void CustomVJP::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -623,40 +204,6 @@ void Depends::eval_gpu(
   eval(inputs, outputs);
 }

-void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "div");
-}
-
-void DivMod::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  binary_op(inputs, outputs, "divmod");
-}
-
-void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "rem");
-}
-
-void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
-}
-
-void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "erf");
-}
-
-void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "erfinv");
-}
-
-void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "exp");
-}
-
-void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "expm1");
-}
-
 void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto in = inputs[0];
   CopyType ctype;
@@ -670,102 +217,14 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
   copy_gpu(in, out, ctype);
 }

-void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "ge");
-}
-
-void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "geq");
-}
-
-void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "le");
-}
-
-void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "leq");
-}
-
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }

-void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
-  switch (base_) {
-    case Base::e:
-      unary_op(inputs, out, "log");
-      break;
-    case Base::two:
-      unary_op(inputs, out, "log2");
-      break;
-    case Base::ten:
-      unary_op(inputs, out, "log10");
-      break;
-  }
-}
-
-void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "log1p");
-}
-
-void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "lnot");
-}
-
-void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(
-      inputs,
-      out,
-      "land"); // Assume "land" is the operation identifier for logical AND
-}
-
-void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(
-      inputs,
-      out,
-      "lor"); // Assume "lor" is the operation identifier for logical OR
-}
-
-void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "lae");
-}
-
-void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "max");
-}
-
-void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "min");
-}
-
 void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }

-void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "floor");
-}
-
-void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "ceil");
-}
-
-void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "mul");
-}
-
-void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
-  ternary_op(inputs, out, "select");
-}
-
-void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "neg");
-}
-
-void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "neq");
-}
-
 void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Inputs must be base input array and scalar val array
   assert(inputs.size() == 2);
@@ -797,10 +256,6 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
   copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
 }

-void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "pow");
-}
-
 void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);

@@ -861,51 +316,12 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }

-void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (issubdtype(in.dtype(), inexact)) {
-    unary_op(inputs, out, "round");
-  } else {
-    // No-op integer types
-    out.copy_shared_buffer(in);
-  }
-}
-
-void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sigmoid");
-}
-
-void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sign");
-}
-
-void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sin");
-}
-
-void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "sinh");
-}
-
 void Split::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   eval(inputs, outputs);
 }

-void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "square");
-}
-
-void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
-  if (recip_) {
-    unary_op(inputs, out, "rsqrt");
-  } else {
-    unary_op(inputs, out, "sqrt");
-  }
-}
-
 void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   if (out.size() == 0) {
@@ -980,18 +396,6 @@ void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }

-void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
-  binary_op(inputs, out, "sub");
-}
-
-void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "tan");
-}
-
-void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
-  unary_op(inputs, out, "tanh");
-}
-
 void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
mlx/backend/metal/ternary.cpp (new file, 108 lines)
@@ -0,0 +1,108 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/common/ternary.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+constexpr int MAX_TERNARY_SPECIALIZED_DIMS = 5;
+
+void ternary_op(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op) {
+  assert(inputs.size() == 3);
+  auto& a = inputs[0];
+  auto& b = inputs[1];
+  auto& c = inputs[2];
+  TernaryOpType topt = get_ternary_op_type(a, b, c);
+  set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
+
+  if (out.size() == 0) {
+    return;
+  }
+
+  // Try to collapse contiguous dims
+  auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
+  auto& strides_a = strides[0];
+  auto& strides_b = strides[1];
+  auto& strides_c = strides[2];
+  auto& strides_out = strides[3];
+
+  std::string kernel_name;
+  {
+    std::ostringstream kname;
+    if (topt == TernaryOpType::General) {
+      kname << "g";
+      if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
+        kname << shape.size();
+      }
+    } else {
+      kname << "v";
+    }
+    kname << "_" << op << type_to_name(b);
+    kernel_name = kname.str();
+  }
+
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+
+  auto kernel = get_ternary_kernel(d, kernel_name, out);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_input_array(c, 2);
+  compute_encoder.set_output_array(out, 3);
+
+  if (topt == TernaryOpType::General) {
+    auto ndim = shape.size();
+    if (ndim > 3) {
+      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
+      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
+      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
+      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 7);
+
+      if (ndim > MAX_TERNARY_SPECIALIZED_DIMS) {
+        compute_encoder->setBytes(&ndim, sizeof(int), 8);
+      }
+    } else {
+      // The shape is implicit in the grid for <= 3D
+      compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
+      compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
+      compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
+    }
+
+    // Launch up to 3D grid of threads
+    size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
+    size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
+    size_t rest = out.size() / (dim0 * dim1);
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    if (thread_group_size != 1024) {
+      throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
+    }
+    MTL::Size group_dims = get_block_dims(dim0, dim1, rest);
+    MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  } else {
+    // Launch a 1D grid of threads
+    size_t nthreads = out.data_size();
+    MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+    if (thread_group_size > nthreads) {
+      thread_group_size = nthreads;
+    }
+    MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  }
+}
+
+void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
+  ternary_op(inputs, out, "select");
+}
+
+} // namespace mlx::core
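Compared with the ternary_op removed from primitives.cpp above, this new file changes two things: the kernel name gains an underscore separator and the dimension specialization moves next to the "g" prefix (with a dedicated MAX_TERNARY_SPECIALIZED_DIMS instead of reusing the binary constant), and the pipeline is looked up through get_ternary_kernel(d, kernel_name, out) rather than d.get_kernel(...), presumably the hook that lets an MLX_METAL_JIT build compile these kernels at run time. The naming scheme, as a standalone sketch (type_to_name() returning "float32" is an assumption for the example):

// Sketch of the kernel-name scheme used by the new ternary.cpp above.
// type_to_name() returning "float32" is an assumption for the example.
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

std::string ternary_kernel_name(
    bool general,                // TernaryOpType::General vs. the "v" fast path
    size_t ndim,                 // collapsed rank, only used when general
    const std::string& op,       // e.g. "select"
    const std::string& tname) {  // e.g. "float32"
  constexpr size_t MAX_TERNARY_SPECIALIZED_DIMS = 5;  // value from the file above
  std::ostringstream kname;
  if (general) {
    kname << "g";
    if (ndim <= MAX_TERNARY_SPECIALIZED_DIMS) {
      kname << ndim;  // e.g. "g3"
    }
  } else {
    kname << "v";
  }
  kname << "_" << op << tname;  // underscore separator is new in this file
  return kname.str();
}

int main() {
  std::cout << ternary_kernel_name(false, 1, "select", "float32") << "\n";  // v_selectfloat32
  std::cout << ternary_kernel_name(true, 3, "select", "float32") << "\n";   // g3_selectfloat32
}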
mlx/backend/metal/unary.cpp (new file, 206 lines)
@@ -0,0 +1,206 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+void unary_op(
+    const std::vector<array>& inputs,
+    array& out,
+    const std::string op) {
+  auto& in = inputs[0];
+  bool contig = in.flags().contiguous;
+  if (contig) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  }
+  if (in.size() == 0) {
+    return;
+  }
+
+  auto& s = out.primitive().stream();
+  auto& d = metal::device(s.device);
+
+  std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
+  auto kernel = get_unary_kernel(d, kernel_name, out);
+
+  size_t nthreads = contig ? in.data_size() : in.size();
+  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
+  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size > nthreads) {
+    thread_group_size = nthreads;
+  }
+  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_input_array(
+      in.data_shared_ptr() == nullptr ? out : in, 0);
+  compute_encoder.set_output_array(out, 1);
+  if (!contig) {
+    compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
+    compute_encoder->setBytes(
+        in.strides().data(), in.ndim() * sizeof(size_t), 3);
+    int ndim = in.ndim();
+    compute_encoder->setBytes(&ndim, sizeof(int), 4);
+  }
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+}
+
+void Abs::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "abs");
+}
+
+void ArcCos::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "arccos");
+}
+
+void ArcCosh::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "arccosh");
+}
+
+void ArcSin::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "arcsin");
+}
+
+void ArcSinh::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "arcsinh");
+}
+
+void ArcTan::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "arctan");
+}
+
+void ArcTanh::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "arctanh");
+}
+
+void Conjugate::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (out.dtype() == complex64) {
+    unary_op(inputs, out, "conj");
+  } else {
+    throw std::invalid_argument(
+        "[conjugate] conjugate must be called on complex input.");
+  }
+}
+
+void Cos::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "cos");
+}
+
+void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "cosh");
+}
+
+void Erf::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "erf");
+}
+
+void ErfInv::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "erfinv");
+}
+
+void Exp::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "exp");
+}
+
+void Expm1::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "expm1");
+}
+
+void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
+  switch (base_) {
+    case Base::e:
+      unary_op(inputs, out, "log");
+      break;
+    case Base::two:
+      unary_op(inputs, out, "log2");
+      break;
+    case Base::ten:
+      unary_op(inputs, out, "log10");
+      break;
+  }
+}
+
+void Log1p::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "log1p");
+}
+
+void LogicalNot::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "lnot");
+}
+
+void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "floor");
+}
+
+void Ceil::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "ceil");
+}
+
+void Negative::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "neg");
+}
+
+void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (issubdtype(in.dtype(), inexact)) {
+    unary_op(inputs, out, "round");
+  } else {
+    // No-op integer types
+    out.copy_shared_buffer(in);
+  }
+}
+
+void Sigmoid::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "sigmoid");
+}
+
+void Sign::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "sign");
+}
+
+void Sin::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "sin");
+}
+
+void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "sinh");
+}
+
+void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "square");
+}
+
+void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
+  if (recip_) {
+    unary_op(inputs, out, "rsqrt");
+  } else {
+    unary_op(inputs, out, "sqrt");
+  }
+}
+
+void Tan::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "tan");
+}
+
+void Tanh::eval_gpu(const std::vector<array>& inputs, array& out) {
+  unary_op(inputs, out, "tanh");
+}
+
+} // namespace mlx::core
@@ -33,9 +33,9 @@ NO_CPU(AsType)
 NO_CPU(AsStrided)
 NO_CPU(BitwiseBinary)
 NO_CPU(BlockMaskedMM)
-NO_CPU(BlockSparseMM)
 NO_CPU(Broadcast)
 NO_CPU(Ceil)
+NO_CPU(Cholesky)
 NO_CPU(Concatenate)
 NO_CPU(Conjugate)
 NO_CPU(Convolution)
@@ -57,6 +57,8 @@ NO_CPU(FFT)
 NO_CPU(Floor)
 NO_CPU(Full)
 NO_CPU(Gather)
+NO_CPU(GatherMM)
+NO_CPU(GatherQMM)
 NO_CPU(Greater)
 NO_CPU(GreaterEqual)
 NO_CPU(Less)
@@ -41,7 +41,7 @@ if (MLX_BUILD_GGUF)
     gguflib STATIC
     ${gguflib_SOURCE_DIR}/fp16.c
     ${gguflib_SOURCE_DIR}/gguflib.c)
-  target_link_libraries(mlx $<BUILD_INTERFACE:gguflib>)
+  target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
   target_sources(
     mlx
     PRIVATE
@@ -708,6 +708,14 @@ std::pair<std::vector<array>, std::vector<int>> Ceil::vmap(
   return {{ceil(inputs[0], stream())}, axes};
 }

+std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::cholesky(a, upper_, stream())}, {ax}};
+}
+
 std::vector<array> Concatenate::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
@@ -870,7 +870,7 @@ class Equal : public UnaryPrimitive {

   void print(std::ostream& os) override {
     if (equal_nan_) {
-      os << "NanEqual";
+      os << "NaNEqual";
     } else {
       os << "Equal";
     }